risingwave_frontend/optimizer/plan_expr_visitor/
strong.rs1use fixedbitset::FixedBitSet;
16
17use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
18
19#[derive(Default)]
38pub struct Strong {
39 null_columns: FixedBitSet,
40}
41
42impl Strong {
43 fn new(null_columns: FixedBitSet) -> Self {
44 Self { null_columns }
45 }
46
47 pub fn is_null(expr: &ExprImpl, null_columns: FixedBitSet) -> bool {
51 let strong = Strong::new(null_columns);
52 strong.is_null_visit(expr)
53 }
54
55 fn is_input_ref_null(&self, input_ref: &InputRef) -> bool {
56 self.null_columns.contains(input_ref.index())
57 }
58
59 fn is_null_visit(&self, expr: &ExprImpl) -> bool {
60 match expr {
61 ExprImpl::InputRef(input_ref) => self.is_input_ref_null(input_ref),
62 ExprImpl::Literal(literal) => literal.get_data().is_none(),
63 ExprImpl::FunctionCall(func_call) => self.is_null_function_call(func_call),
64 ExprImpl::FunctionCallWithLambda(_) => false,
65 ExprImpl::AggCall(_) => false,
66 ExprImpl::Subquery(_) => false,
67 ExprImpl::CorrelatedInputRef(_) => false,
68 ExprImpl::TableFunction(_) => false,
69 ExprImpl::WindowFunction(_) => false,
70 ExprImpl::UserDefinedFunction(_) => false,
71 ExprImpl::Parameter(_) => false,
72 ExprImpl::Now(_) => false,
73 }
74 }
75
76 fn is_null_function_call(&self, func_call: &FunctionCall) -> bool {
77 match func_call.func_type() {
78 ExprType::IsNull
80 | ExprType::IsNotNull
81 | ExprType::IsDistinctFrom
82 | ExprType::IsNotDistinctFrom
83 | ExprType::IsTrue
84 | ExprType::QuoteNullable
85 | ExprType::IsNotTrue
86 | ExprType::IsFalse
87 | ExprType::IsNotFalse
88 | ExprType::CheckNotNull => false,
89 ExprType::Not
91 | ExprType::Equal
92 | ExprType::NotEqual
93 | ExprType::LessThan
94 | ExprType::LessThanOrEqual
95 | ExprType::GreaterThan
96 | ExprType::GreaterThanOrEqual
97 | ExprType::Like
98 | ExprType::Add
99 | ExprType::AddWithTimeZone
100 | ExprType::Subtract
101 | ExprType::Multiply
102 | ExprType::Modulus
103 | ExprType::Divide
104 | ExprType::Cast
105 | ExprType::Trim
106 | ExprType::Ltrim
107 | ExprType::Rtrim
108 | ExprType::Ceil
109 | ExprType::Floor
110 | ExprType::Extract
111 | ExprType::L2Distance
112 | ExprType::Greatest
113 | ExprType::Least => self.any_null(func_call),
114 ExprType::And | ExprType::Or | ExprType::Coalesce => self.all_null(func_call),
116 ExprType::In
119 | ExprType::Some
120 | ExprType::All
121 | ExprType::BitwiseAnd
122 | ExprType::BitwiseOr
123 | ExprType::BitwiseXor
124 | ExprType::BitwiseNot
125 | ExprType::BitwiseShiftLeft
126 | ExprType::BitwiseShiftRight
127 | ExprType::DatePart
128 | ExprType::TumbleStart
129 | ExprType::MakeDate
130 | ExprType::MakeTime
131 | ExprType::MakeTimestamp
132 | ExprType::SecToTimestamptz
133 | ExprType::AtTimeZone
134 | ExprType::DateTrunc
135 | ExprType::DateBin
136 | ExprType::CharToTimestamptz
137 | ExprType::CharToDate
138 | ExprType::CastWithTimeZone
139 | ExprType::SubtractWithTimeZone
140 | ExprType::MakeTimestamptz
141 | ExprType::Substr
142 | ExprType::Length
143 | ExprType::ILike
144 | ExprType::SimilarToEscape
145 | ExprType::Upper
146 | ExprType::Lower
147 | ExprType::Replace
148 | ExprType::Position
149 | ExprType::Case
150 | ExprType::ConstantLookup
151 | ExprType::RoundDigit
152 | ExprType::Round
153 | ExprType::Ascii
154 | ExprType::Translate
155 | ExprType::Concat
156 | ExprType::ConcatVariadic
157 | ExprType::ConcatWs
158 | ExprType::ConcatWsVariadic
159 | ExprType::Abs
160 | ExprType::SplitPart
161 | ExprType::ToChar
162 | ExprType::Md5
163 | ExprType::CharLength
164 | ExprType::Repeat
165 | ExprType::ConcatOp
166 | ExprType::ByteaConcatOp
167 | ExprType::BoolOut
168 | ExprType::OctetLength
169 | ExprType::BitLength
170 | ExprType::Overlay
171 | ExprType::RegexpMatch
172 | ExprType::RegexpReplace
173 | ExprType::RegexpCount
174 | ExprType::RegexpSplitToArray
175 | ExprType::RegexpEq
176 | ExprType::Pow
177 | ExprType::Exp
178 | ExprType::Chr
179 | ExprType::StartsWith
180 | ExprType::Initcap
181 | ExprType::Lpad
182 | ExprType::Rpad
183 | ExprType::Reverse
184 | ExprType::Strpos
185 | ExprType::ToAscii
186 | ExprType::ToHex
187 | ExprType::QuoteIdent
188 | ExprType::QuoteLiteral
189 | ExprType::Sin
190 | ExprType::Cos
191 | ExprType::Tan
192 | ExprType::Cot
193 | ExprType::Asin
194 | ExprType::Acos
195 | ExprType::Acosd
196 | ExprType::Atan
197 | ExprType::Atan2
198 | ExprType::Atand
199 | ExprType::Atan2d
200 | ExprType::Sind
201 | ExprType::Cosd
202 | ExprType::Cotd
203 | ExprType::Tand
204 | ExprType::Asind
205 | ExprType::Sqrt
206 | ExprType::Degrees
207 | ExprType::Radians
208 | ExprType::Cosh
209 | ExprType::Tanh
210 | ExprType::Coth
211 | ExprType::Asinh
212 | ExprType::Acosh
213 | ExprType::Atanh
214 | ExprType::Sinh
215 | ExprType::Trunc
216 | ExprType::Ln
217 | ExprType::Log10
218 | ExprType::Cbrt
219 | ExprType::Sign
220 | ExprType::Scale
221 | ExprType::MinScale
222 | ExprType::TrimScale
223 | ExprType::Encode
224 | ExprType::Decode
225 | ExprType::Sha1
226 | ExprType::Sha224
227 | ExprType::Sha256
228 | ExprType::Sha384
229 | ExprType::Sha512
230 | ExprType::Hmac
231 | ExprType::SecureCompare
232 | ExprType::Left
233 | ExprType::Right
234 | ExprType::Format
235 | ExprType::FormatVariadic
236 | ExprType::PgwireSend
237 | ExprType::PgwireRecv
238 | ExprType::ConvertFrom
239 | ExprType::ConvertTo
240 | ExprType::Decrypt
241 | ExprType::Encrypt
242 | ExprType::Neg
243 | ExprType::Field
244 | ExprType::Array
245 | ExprType::ArrayAccess
246 | ExprType::Row
247 | ExprType::ArrayToString
248 | ExprType::ArrayRangeAccess
249 | ExprType::ArrayCat
250 | ExprType::ArrayAppend
251 | ExprType::ArrayPrepend
252 | ExprType::FormatType
253 | ExprType::ArrayDistinct
254 | ExprType::ArrayLength
255 | ExprType::Cardinality
256 | ExprType::ArrayRemove
257 | ExprType::ArrayPositions
258 | ExprType::TrimArray
259 | ExprType::StringToArray
260 | ExprType::ArrayPosition
261 | ExprType::ArrayReplace
262 | ExprType::ArrayDims
263 | ExprType::ArrayTransform
264 | ExprType::ArrayMin
265 | ExprType::ArrayMax
266 | ExprType::ArraySum
267 | ExprType::ArraySort
268 | ExprType::ArrayContains
269 | ExprType::ArrayContained
270 | ExprType::ArrayFlatten
271 | ExprType::HexToInt256
272 | ExprType::JsonbAccess
273 | ExprType::JsonbAccessStr
274 | ExprType::JsonbExtractPath
275 | ExprType::JsonbExtractPathVariadic
276 | ExprType::JsonbExtractPathText
277 | ExprType::JsonbExtractPathTextVariadic
278 | ExprType::JsonbTypeof
279 | ExprType::JsonbArrayLength
280 | ExprType::IsJson
281 | ExprType::JsonbConcat
282 | ExprType::JsonbObject
283 | ExprType::JsonbPretty
284 | ExprType::JsonbContains
285 | ExprType::JsonbContained
286 | ExprType::JsonbExists
287 | ExprType::JsonbExistsAny
288 | ExprType::JsonbExistsAll
289 | ExprType::JsonbDeletePath
290 | ExprType::JsonbStripNulls
291 | ExprType::ToJsonb
292 | ExprType::JsonbBuildArray
293 | ExprType::JsonbBuildArrayVariadic
294 | ExprType::JsonbBuildObject
295 | ExprType::JsonbBuildObjectVariadic
296 | ExprType::JsonbPathExists
297 | ExprType::JsonbPathMatch
298 | ExprType::JsonbPathQueryArray
299 | ExprType::JsonbPathQueryFirst
300 | ExprType::JsonbPopulateRecord
301 | ExprType::JsonbToRecord
302 | ExprType::JsonbSet
303 | ExprType::JsonbPopulateMap
304 | ExprType::MapFromEntries
305 | ExprType::MapAccess
306 | ExprType::MapKeys
307 | ExprType::MapValues
308 | ExprType::MapEntries
309 | ExprType::MapFromKeyValues
310 | ExprType::MapCat
311 | ExprType::MapContains
312 | ExprType::MapDelete
313 | ExprType::MapFilter
314 | ExprType::MapInsert
315 | ExprType::MapLength
316 | ExprType::Vnode
317 | ExprType::VnodeUser
318 | ExprType::TestPaidTier
319 | ExprType::License
320 | ExprType::Proctime
321 | ExprType::PgSleep
322 | ExprType::PgSleepFor
323 | ExprType::PgSleepUntil
324 | ExprType::CastRegclass
325 | ExprType::PgGetIndexdef
326 | ExprType::ColDescription
327 | ExprType::PgGetViewdef
328 | ExprType::PgGetUserbyid
329 | ExprType::PgIndexesSize
330 | ExprType::PgRelationSize
331 | ExprType::PgGetSerialSequence
332 | ExprType::PgIndexColumnHasProperty
333 | ExprType::PgIsInRecovery
334 | ExprType::PgTableIsVisible
335 | ExprType::RwRecoveryStatus
336 | ExprType::IcebergTransform
337 | ExprType::HasTablePrivilege
338 | ExprType::HasFunctionPrivilege
339 | ExprType::HasAnyColumnPrivilege
340 | ExprType::HasSchemaPrivilege
341 | ExprType::InetAton
342 | ExprType::InetNtoa
343 | ExprType::CompositeCast
344 | ExprType::RwEpochToTs
345 | ExprType::OpenaiEmbedding => false,
346 ExprType::Unspecified => unreachable!(),
347 }
348 }
349
350 fn any_null(&self, func_call: &FunctionCall) -> bool {
351 func_call
352 .inputs()
353 .iter()
354 .any(|expr| self.is_null_visit(expr))
355 }
356
357 fn all_null(&self, func_call: &FunctionCall) -> bool {
358 func_call
359 .inputs()
360 .iter()
361 .all(|expr| self.is_null_visit(expr))
362 }
363}
364
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;
    use crate::expr::ExprImpl::Literal;

    /// Builds a column bit set of the given `capacity` with every index in `null` set.
    fn null_set(capacity: usize, null: &[usize]) -> FixedBitSet {
        let mut columns = FixedBitSet::with_capacity(capacity);
        for &index in null {
            columns.insert(index);
        }
        columns
    }

    /// Wraps a column reference into an expression.
    fn column(index: usize, data_type: DataType) -> ExprImpl {
        InputRef::new(index, data_type).into()
    }

    /// Wraps a literal into an expression.
    fn constant(literal: crate::expr::Literal) -> ExprImpl {
        Literal(literal.into())
    }

    /// Builds a function-call expression without type checking.
    fn call(func_type: ExprType, inputs: Vec<ExprImpl>, return_type: DataType) -> ExprImpl {
        FunctionCall::new_unchecked(func_type, inputs, return_type).into()
    }

    #[test]
    fn test_literal() {
        let columns = null_set(1, &[]);

        let null_varchar = constant(crate::expr::Literal::new(None, DataType::Varchar));
        assert!(Strong::is_null(&null_varchar, columns.clone()));

        let non_null_varchar = constant(crate::expr::Literal::new(
            Some("test".to_owned().into()),
            DataType::Varchar,
        ));
        assert!(!Strong::is_null(&non_null_varchar, columns));
    }

    #[test]
    fn test_input_ref1() {
        // No column is assumed NULL, so no reference is provably NULL.
        let columns = null_set(2, &[]);
        for index in [0, 1] {
            let expr = column(index, DataType::Varchar);
            assert!(!Strong::is_null(&expr, columns.clone()));
        }
    }

    #[test]
    fn test_input_ref2() {
        // Every column is assumed NULL, so every reference is NULL.
        let columns = null_set(2, &[0, 1]);
        for index in [0, 1] {
            let expr = column(index, DataType::Varchar);
            assert!(Strong::is_null(&expr, columns.clone()));
        }
    }

    #[test]
    fn test_c1_equal_1_or_c2_is_null() {
        let c1_equals_one = call(
            ExprType::Equal,
            vec![
                column(0, DataType::Int64),
                constant(crate::expr::Literal::new(Some(1.into()), DataType::Int32)),
            ],
            DataType::Boolean,
        );
        let c2_is_null = call(
            ExprType::IsNull,
            vec![column(1, DataType::Int64)],
            DataType::Boolean,
        );
        let disjunction = call(
            ExprType::Or,
            vec![c1_equals_one, c2_is_null],
            DataType::Boolean,
        );
        // Only the left disjunct is NULL; `IS NULL` never is, so the OR is not provably NULL.
        assert!(!Strong::is_null(&disjunction, null_set(2, &[0])));
    }

    #[test]
    fn test_divide() {
        let quotient = call(
            ExprType::Divide,
            vec![column(0, DataType::Decimal), column(1, DataType::Decimal)],
            DataType::Varchar,
        );
        assert!(Strong::is_null(&quotient, null_set(2, &[0, 1])));
    }

    #[test]
    fn test_multiply_divide() {
        // 0.8 * (c0 / c1) with c0 NULL: the NULL propagates through both strict operators.
        let quotient = call(
            ExprType::Divide,
            vec![column(0, DataType::Decimal), column(1, DataType::Decimal)],
            DataType::Decimal,
        );
        let product = call(
            ExprType::Multiply,
            vec![
                constant(crate::expr::Literal::new(
                    Some(0.8f64.into()),
                    DataType::Float64,
                )),
                quotient,
            ],
            DataType::Decimal,
        );
        assert!(Strong::is_null(&product, null_set(2, &[0])));
    }

    /// Generates a test that applies `$func_type` to `$inputs` (with a `Varchar` return type),
    /// assumes no column is NULL, and checks `Strong::is_null` against `$expected`.
    macro_rules! gen_test {
        ($name:ident, $func_type:expr, $inputs:expr, $expected:expr) => {
            #[test]
            fn $name() {
                let expr = call($func_type, $inputs, DataType::Varchar);
                assert_eq!(Strong::is_null(&expr, null_set(2, &[])), $expected);
            }
        };
    }

    gen_test!(
        test_is_not_null,
        ExprType::IsNotNull,
        vec![column(0, DataType::Varchar)],
        false
    );
    gen_test!(
        test_is_null,
        ExprType::IsNull,
        vec![column(0, DataType::Varchar)],
        false
    );
    gen_test!(
        test_is_distinct_from,
        ExprType::IsDistinctFrom,
        vec![column(0, DataType::Varchar), column(1, DataType::Varchar)],
        false
    );
    gen_test!(
        test_is_not_distinct_from,
        ExprType::IsNotDistinctFrom,
        vec![column(0, DataType::Varchar), column(1, DataType::Varchar)],
        false
    );
    gen_test!(
        test_is_true,
        ExprType::IsTrue,
        vec![column(0, DataType::Varchar)],
        false
    );
    gen_test!(
        test_is_not_true,
        ExprType::IsNotTrue,
        vec![column(0, DataType::Varchar)],
        false
    );
    gen_test!(
        test_is_false,
        ExprType::IsFalse,
        vec![column(0, DataType::Varchar)],
        false
    );
}