1use std::fmt::Debug;
16
17use constant_time_eq::constant_time_eq;
18use risingwave_common::array::{Array, BoolArray};
19use risingwave_common::bitmap::Bitmap;
20use risingwave_common::row::Row;
21use risingwave_common::types::{Scalar, ScalarRef, ScalarRefImpl};
22use risingwave_expr::{ExprError, Result, function};
23
24#[function("equal(boolean, boolean) -> boolean", batch_fn = "boolarray_eq")]
25#[function("equal(*int, *int) -> boolean")]
26#[function("equal(decimal, decimal) -> boolean")]
27#[function("equal(*float, *float) -> boolean")]
28#[function("equal(int256, int256) -> boolean")]
29#[function("equal(serial, serial) -> boolean")]
30#[function("equal(date, date) -> boolean")]
31#[function("equal(time, time) -> boolean")]
32#[function("equal(interval, interval) -> boolean")]
33#[function("equal(timestamp, timestamp) -> boolean")]
34#[function("equal(timestamptz, timestamptz) -> boolean")]
35#[function("equal(date, timestamp) -> boolean")]
36#[function("equal(timestamp, date) -> boolean")]
37#[function("equal(time, interval) -> boolean")]
38#[function("equal(interval, time) -> boolean")]
39#[function("equal(varchar, varchar) -> boolean")]
40#[function("equal(bytea, bytea) -> boolean")]
41#[function("equal(jsonb, jsonb) -> boolean")]
42#[function("equal(anyarray, anyarray) -> boolean")]
43#[function("equal(struct, struct) -> boolean")]
/// Returns `true` when `l` and `r` are equal once both have been converted
/// into the common comparison type `T3`.
///
/// The operand types may differ; each side goes through [`Into`] first so the
/// comparison itself is performed on a single type.
pub fn general_eq<T1, T2, T3>(l: T1, r: T2) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: T3 = l.into();
    let rhs: T3 = r.into();
    lhs == rhs
}
52
53#[function("not_equal(boolean, boolean) -> boolean", batch_fn = "boolarray_ne")]
54#[function("not_equal(*int, *int) -> boolean")]
55#[function("not_equal(decimal, decimal) -> boolean")]
56#[function("not_equal(*float, *float) -> boolean")]
57#[function("not_equal(int256, int256) -> boolean")]
58#[function("not_equal(serial, serial) -> boolean")]
59#[function("not_equal(date, date) -> boolean")]
60#[function("not_equal(time, time) -> boolean")]
61#[function("not_equal(interval, interval) -> boolean")]
62#[function("not_equal(timestamp, timestamp) -> boolean")]
63#[function("not_equal(timestamptz, timestamptz) -> boolean")]
64#[function("not_equal(date, timestamp) -> boolean")]
65#[function("not_equal(timestamp, date) -> boolean")]
66#[function("not_equal(time, interval) -> boolean")]
67#[function("not_equal(interval, time) -> boolean")]
68#[function("not_equal(varchar, varchar) -> boolean")]
69#[function("not_equal(bytea, bytea) -> boolean")]
70#[function("not_equal(jsonb, jsonb) -> boolean")]
71#[function("not_equal(anyarray, anyarray) -> boolean")]
72#[function("not_equal(struct, struct) -> boolean")]
/// Returns `true` when `l` and `r` differ once both have been converted into
/// the common comparison type `T3`.
pub fn general_ne<T1, T2, T3>(l: T1, r: T2) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: T3 = l.into();
    let rhs: T3 = r.into();
    lhs != rhs
}
81
82#[function(
83 "greater_than_or_equal(boolean, boolean) -> boolean",
84 batch_fn = "boolarray_ge"
85)]
86#[function("greater_than_or_equal(*int, *int) -> boolean")]
87#[function("greater_than_or_equal(decimal, decimal) -> boolean")]
88#[function("greater_than_or_equal(*float, *float) -> boolean")]
89#[function("greater_than_or_equal(serial, serial) -> boolean")]
90#[function("greater_than_or_equal(int256, int256) -> boolean")]
91#[function("greater_than_or_equal(date, date) -> boolean")]
92#[function("greater_than_or_equal(time, time) -> boolean")]
93#[function("greater_than_or_equal(interval, interval) -> boolean")]
94#[function("greater_than_or_equal(timestamp, timestamp) -> boolean")]
95#[function("greater_than_or_equal(timestamptz, timestamptz) -> boolean")]
96#[function("greater_than_or_equal(date, timestamp) -> boolean")]
97#[function("greater_than_or_equal(timestamp, date) -> boolean")]
98#[function("greater_than_or_equal(time, interval) -> boolean")]
99#[function("greater_than_or_equal(interval, time) -> boolean")]
100#[function("greater_than_or_equal(varchar, varchar) -> boolean")]
101#[function("greater_than_or_equal(bytea, bytea) -> boolean")]
102#[function("greater_than_or_equal(anyarray, anyarray) -> boolean")]
103#[function("greater_than_or_equal(struct, struct) -> boolean")]
/// Returns `true` when `l >= r` after both operands have been converted into
/// the common comparison type `T3`.
pub fn general_ge<T1, T2, T3>(l: T1, r: T2) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: T3 = l.into();
    let rhs: T3 = r.into();
    lhs >= rhs
}
112
113#[function("greater_than(boolean, boolean) -> boolean", batch_fn = "boolarray_gt")]
114#[function("greater_than(*int, *int) -> boolean")]
115#[function("greater_than(decimal, decimal) -> boolean")]
116#[function("greater_than(*float, *float) -> boolean")]
117#[function("greater_than(serial, serial) -> boolean")]
118#[function("greater_than(int256, int256) -> boolean")]
119#[function("greater_than(date, date) -> boolean")]
120#[function("greater_than(time, time) -> boolean")]
121#[function("greater_than(interval, interval) -> boolean")]
122#[function("greater_than(timestamp, timestamp) -> boolean")]
123#[function("greater_than(timestamptz, timestamptz) -> boolean")]
124#[function("greater_than(date, timestamp) -> boolean")]
125#[function("greater_than(timestamp, date) -> boolean")]
126#[function("greater_than(time, interval) -> boolean")]
127#[function("greater_than(interval, time) -> boolean")]
128#[function("greater_than(varchar, varchar) -> boolean")]
129#[function("greater_than(bytea, bytea) -> boolean")]
130#[function("greater_than(anyarray, anyarray) -> boolean")]
131#[function("greater_than(struct, struct) -> boolean")]
/// Returns `true` when `l > r` after both operands have been converted into
/// the common comparison type `T3`.
pub fn general_gt<T1, T2, T3>(l: T1, r: T2) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: T3 = l.into();
    let rhs: T3 = r.into();
    lhs > rhs
}
140
141#[function(
142 "less_than_or_equal(boolean, boolean) -> boolean",
143 batch_fn = "boolarray_le"
144)]
145#[function("less_than_or_equal(*int, *int) -> boolean")]
146#[function("less_than_or_equal(decimal, decimal) -> boolean")]
147#[function("less_than_or_equal(*float, *float) -> boolean")]
148#[function("less_than_or_equal(serial, serial) -> boolean")]
149#[function("less_than_or_equal(int256, int256) -> boolean")]
150#[function("less_than_or_equal(date, date) -> boolean")]
151#[function("less_than_or_equal(time, time) -> boolean")]
152#[function("less_than_or_equal(interval, interval) -> boolean")]
153#[function("less_than_or_equal(timestamp, timestamp) -> boolean")]
154#[function("less_than_or_equal(timestamptz, timestamptz) -> boolean")]
155#[function("less_than_or_equal(date, timestamp) -> boolean")]
156#[function("less_than_or_equal(timestamp, date) -> boolean")]
157#[function("less_than_or_equal(time, interval) -> boolean")]
158#[function("less_than_or_equal(interval, time) -> boolean")]
159#[function("less_than_or_equal(varchar, varchar) -> boolean")]
160#[function("less_than_or_equal(bytea, bytea) -> boolean")]
161#[function("less_than_or_equal(anyarray, anyarray) -> boolean")]
162#[function("less_than_or_equal(struct, struct) -> boolean")]
/// Returns `true` when `l <= r` after both operands have been converted into
/// the common comparison type `T3`.
pub fn general_le<T1, T2, T3>(l: T1, r: T2) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: T3 = l.into();
    let rhs: T3 = r.into();
    lhs <= rhs
}
171
172#[function("less_than(boolean, boolean) -> boolean", batch_fn = "boolarray_lt")]
173#[function("less_than(*int, *int) -> boolean")]
174#[function("less_than(decimal, decimal) -> boolean")]
175#[function("less_than(*float, *float) -> boolean")]
176#[function("less_than(serial, serial) -> boolean")]
177#[function("less_than(int256, int256) -> boolean")]
178#[function("less_than(date, date) -> boolean")]
179#[function("less_than(time, time) -> boolean")]
180#[function("less_than(interval, interval) -> boolean")]
181#[function("less_than(timestamp, timestamp) -> boolean")]
182#[function("less_than(timestamptz, timestamptz) -> boolean")]
183#[function("less_than(date, timestamp) -> boolean")]
184#[function("less_than(timestamp, date) -> boolean")]
185#[function("less_than(time, interval) -> boolean")]
186#[function("less_than(interval, time) -> boolean")]
187#[function("less_than(varchar, varchar) -> boolean")]
188#[function("less_than(bytea, bytea) -> boolean")]
189#[function("less_than(anyarray, anyarray) -> boolean")]
190#[function("less_than(struct, struct) -> boolean")]
/// Returns `true` when `l < r` after both operands have been converted into
/// the common comparison type `T3`.
pub fn general_lt<T1, T2, T3>(l: T1, r: T2) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: T3 = l.into();
    let rhs: T3 = r.into();
    lhs < rhs
}
199
200#[function(
201 "is_distinct_from(boolean, boolean) -> boolean",
202 batch_fn = "boolarray_is_distinct_from"
203)]
204#[function("is_distinct_from(*int, *int) -> boolean")]
205#[function("is_distinct_from(decimal, decimal) -> boolean")]
206#[function("is_distinct_from(*float, *float) -> boolean")]
207#[function("is_distinct_from(serial, serial) -> boolean")]
208#[function("is_distinct_from(int256, int256) -> boolean")]
209#[function("is_distinct_from(date, date) -> boolean")]
210#[function("is_distinct_from(time, time) -> boolean")]
211#[function("is_distinct_from(interval, interval) -> boolean")]
212#[function("is_distinct_from(timestamp, timestamp) -> boolean")]
213#[function("is_distinct_from(timestamptz, timestamptz) -> boolean")]
214#[function("is_distinct_from(date, timestamp) -> boolean")]
215#[function("is_distinct_from(timestamp, date) -> boolean")]
216#[function("is_distinct_from(time, interval) -> boolean")]
217#[function("is_distinct_from(interval, time) -> boolean")]
218#[function("is_distinct_from(varchar, varchar) -> boolean")]
219#[function("is_distinct_from(bytea, bytea) -> boolean")]
220#[function("is_distinct_from(anyarray, anyarray) -> boolean")]
221#[function("is_distinct_from(struct, struct) -> boolean")]
/// NULL-aware inequality: `None` stands for NULL and is treated as an ordinary
/// value, so the result is always a definite `bool`.
///
/// Two NULLs are not distinct, a NULL and a non-NULL are distinct, and two
/// non-NULL operands are compared after conversion into `T3`.
pub fn general_is_distinct_from<T1, T2, T3>(l: Option<T1>, r: Option<T2>) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: Option<T3> = l.map(Into::into);
    let rhs: Option<T3> = r.map(Into::into);
    lhs != rhs
}
230
231#[function(
232 "is_not_distinct_from(boolean, boolean) -> boolean",
233 batch_fn = "boolarray_is_not_distinct_from"
234)]
235#[function("is_not_distinct_from(*int, *int) -> boolean")]
236#[function("is_not_distinct_from(decimal, decimal) -> boolean")]
237#[function("is_not_distinct_from(*float, *float) -> boolean")]
238#[function("is_not_distinct_from(serial, serial) -> boolean")]
239#[function("is_not_distinct_from(int256, int256) -> boolean")]
240#[function("is_not_distinct_from(date, date) -> boolean")]
241#[function("is_not_distinct_from(time, time) -> boolean")]
242#[function("is_not_distinct_from(interval, interval) -> boolean")]
243#[function("is_not_distinct_from(timestamp, timestamp) -> boolean")]
244#[function("is_not_distinct_from(timestamptz, timestamptz) -> boolean")]
245#[function("is_not_distinct_from(date, timestamp) -> boolean")]
246#[function("is_not_distinct_from(timestamp, date) -> boolean")]
247#[function("is_not_distinct_from(time, interval) -> boolean")]
248#[function("is_not_distinct_from(interval, time) -> boolean")]
249#[function("is_not_distinct_from(varchar, varchar) -> boolean")]
250#[function("is_not_distinct_from(bytea, bytea) -> boolean")]
251#[function("is_not_distinct_from(anyarray, anyarray) -> boolean")]
252#[function("is_not_distinct_from(struct, struct) -> boolean")]
/// NULL-aware equality: `None` stands for NULL and is treated as an ordinary
/// value, so the result is always a definite `bool`.
///
/// Two NULLs are considered equal, a NULL never equals a non-NULL, and two
/// non-NULL operands are compared after conversion into `T3`.
pub fn general_is_not_distinct_from<T1, T2, T3>(l: Option<T1>, r: Option<T2>) -> bool
where
    T1: Into<T3> + Debug,
    T2: Into<T3> + Debug,
    T3: Ord,
{
    let lhs: Option<T3> = l.map(Into::into);
    let rhs: Option<T3> = r.map(Into::into);
    lhs == rhs
}
261
262#[function("is_true(boolean) -> boolean", batch_fn = "boolarray_is_true")]
/// `IS TRUE`: yields `true` only for a non-NULL `true`; NULL (`None`) yields `false`.
pub fn is_true(v: Option<bool>) -> bool {
    matches!(v, Some(true))
}
266
267#[function("is_not_true(boolean) -> boolean", batch_fn = "boolarray_is_not_true")]
/// `IS NOT TRUE`: yields `true` for `false` and for NULL (`None`).
pub fn is_not_true(v: Option<bool>) -> bool {
    !matches!(v, Some(true))
}
271
272#[function("is_false(boolean) -> boolean", batch_fn = "boolarray_is_false")]
/// `IS FALSE`: yields `true` only for a non-NULL `false`; NULL (`None`) yields `false`.
pub fn is_false(v: Option<bool>) -> bool {
    matches!(v, Some(false))
}
276
277#[function(
278 "is_not_false(boolean) -> boolean",
279 batch_fn = "boolarray_is_not_false"
280)]
/// `IS NOT FALSE`: yields `true` for `true` and for NULL (`None`).
pub fn is_not_false(v: Option<bool>) -> bool {
    !matches!(v, Some(false))
}
284
285#[function("is_null(*) -> boolean", batch_fn = "batch_is_null")]
/// `IS NULL` for any input type: `true` exactly when the value is absent.
fn is_null<T>(v: Option<T>) -> bool {
    Option::is_none(&v)
}
289
290#[function("is_not_null(*) -> boolean", batch_fn = "batch_is_not_null")]
/// `IS NOT NULL` for any input type: `true` exactly when a value is present.
fn is_not_null<T>(v: Option<T>) -> bool {
    Option::is_some(&v)
}
294
295#[function("greatest(...) -> boolean")]
296#[function("greatest(...) -> *int")]
297#[function("greatest(...) -> decimal")]
298#[function("greatest(...) -> *float")]
299#[function("greatest(...) -> serial")]
300#[function("greatest(...) -> int256")]
301#[function("greatest(...) -> date")]
302#[function("greatest(...) -> time")]
303#[function("greatest(...) -> interval")]
304#[function("greatest(...) -> timestamp")]
305#[function("greatest(...) -> timestamptz")]
306#[function("greatest(...) -> varchar")]
307#[function("greatest(...) -> bytea")]
308pub fn general_variadic_greatest<T>(row: impl Row) -> Option<T>
309where
310 T: Scalar,
311 for<'a> <T as Scalar>::ScalarRefType<'a>: TryFrom<ScalarRefImpl<'a>> + Ord + Debug,
312{
313 row.iter()
314 .flatten()
315 .map(
316 |scalar| match <<T as Scalar>::ScalarRefType<'_>>::try_from(scalar) {
317 Ok(v) => v,
318 Err(_) => unreachable!("all input type should have been aligned in the frontend"),
319 },
320 )
321 .max()
322 .map(|v| v.to_owned_scalar())
323}
324
325#[function("least(...) -> boolean")]
326#[function("least(...) -> *int")]
327#[function("least(...) -> decimal")]
328#[function("least(...) -> *float")]
329#[function("least(...) -> serial")]
330#[function("least(...) -> int256")]
331#[function("least(...) -> date")]
332#[function("least(...) -> time")]
333#[function("least(...) -> interval")]
334#[function("least(...) -> timestamp")]
335#[function("least(...) -> timestamptz")]
336#[function("least(...) -> varchar")]
337#[function("least(...) -> bytea")]
338pub fn general_variadic_least<T>(row: impl Row) -> Option<T>
339where
340 T: Scalar,
341 for<'a> <T as Scalar>::ScalarRefType<'a>: TryFrom<ScalarRefImpl<'a>> + Ord + Debug,
342{
343 row.iter()
344 .flatten()
345 .map(
346 |scalar| match <<T as Scalar>::ScalarRefType<'_>>::try_from(scalar) {
347 Ok(v) => v,
348 Err(_) => unreachable!("all input type should have been aligned in the frontend"),
349 },
350 )
351 .min()
352 .map(|v| v.to_owned_scalar())
353}
354
355fn boolarray_eq(l: &BoolArray, r: &BoolArray) -> BoolArray {
358 let data = !(l.data() ^ r.data());
359 let bitmap = l.null_bitmap() & r.null_bitmap();
360 BoolArray::new(data, bitmap)
361}
362
363fn boolarray_ne(l: &BoolArray, r: &BoolArray) -> BoolArray {
364 let data = l.data() ^ r.data();
365 let bitmap = l.null_bitmap() & r.null_bitmap();
366 BoolArray::new(data, bitmap)
367}
368
369fn boolarray_gt(l: &BoolArray, r: &BoolArray) -> BoolArray {
370 let data = l.data() & !r.data();
371 let bitmap = l.null_bitmap() & r.null_bitmap();
372 BoolArray::new(data, bitmap)
373}
374
375fn boolarray_lt(l: &BoolArray, r: &BoolArray) -> BoolArray {
376 let data = !l.data() & r.data();
377 let bitmap = l.null_bitmap() & r.null_bitmap();
378 BoolArray::new(data, bitmap)
379}
380
381fn boolarray_ge(l: &BoolArray, r: &BoolArray) -> BoolArray {
382 let data = l.data() | !r.data();
383 let bitmap = l.null_bitmap() & r.null_bitmap();
384 BoolArray::new(data, bitmap)
385}
386
387fn boolarray_le(l: &BoolArray, r: &BoolArray) -> BoolArray {
388 let data = !l.data() | r.data();
389 let bitmap = l.null_bitmap() & r.null_bitmap();
390 BoolArray::new(data, bitmap)
391}
392
393fn boolarray_is_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray {
394 let data = ((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap()))
395 | (l.null_bitmap() ^ r.null_bitmap());
396 BoolArray::new(data, Bitmap::ones(l.len()))
397}
398
399fn boolarray_is_not_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray {
400 let data = !(((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap()))
401 | (l.null_bitmap() ^ r.null_bitmap()));
402 BoolArray::new(data, Bitmap::ones(l.len()))
403}
404
405fn boolarray_is_true(a: &BoolArray) -> BoolArray {
406 BoolArray::new(a.to_bitmap(), Bitmap::ones(a.len()))
407}
408
409fn boolarray_is_not_true(a: &BoolArray) -> BoolArray {
410 BoolArray::new(!a.to_bitmap(), Bitmap::ones(a.len()))
411}
412
413fn boolarray_is_false(a: &BoolArray) -> BoolArray {
414 BoolArray::new(!a.data() & a.null_bitmap(), Bitmap::ones(a.len()))
415}
416
417fn boolarray_is_not_false(a: &BoolArray) -> BoolArray {
418 BoolArray::new(a.data() | !a.null_bitmap(), Bitmap::ones(a.len()))
419}
420
421fn batch_is_null(a: &impl Array) -> BoolArray {
422 BoolArray::new(!a.null_bitmap(), Bitmap::ones(a.len()))
423}
424
425fn batch_is_not_null(a: &impl Array) -> BoolArray {
426 BoolArray::new(a.null_bitmap().clone(), Bitmap::ones(a.len()))
427}
428
429#[function("secure_compare(varchar, varchar) -> boolean")]
430pub fn secure_compare(left: &str, right: &str) -> bool {
431 constant_time_eq(left.as_bytes(), right.as_bytes())
432}
433
434#[function("check_not_null(any, varchar, varchar) -> any")]
435fn check_not_null<'a>(
436 v: Option<ScalarRefImpl<'a>>,
437 col_name: &str,
438 relation_name: &str,
439) -> Result<Option<ScalarRefImpl<'a>>> {
440 if v.is_none() {
441 return Err(ExprError::NotNullViolation {
442 col_name: col_name.into(),
443 table_name: relation_name.into(),
444 });
445 }
446 Ok(v)
447}
448
449#[cfg(test)]
450mod tests {
451 use std::str::FromStr;
452
453 use risingwave_common::types::{Decimal, F32, F64, Timestamp};
454 use risingwave_expr::expr::build_from_pretty;
455
456 use super::*;
457
458 #[test]
459 fn test_comparison() {
460 assert!(general_eq::<Decimal, i32, Decimal>(dec("1.0"), 1));
461 assert!(!general_ne::<Decimal, i32, Decimal>(dec("1.0"), 1));
462 assert!(!general_gt::<Decimal, i32, Decimal>(dec("1.0"), 2));
463 assert!(general_le::<Decimal, i32, Decimal>(dec("1.0"), 2));
464 assert!(!general_ge::<Decimal, i32, Decimal>(dec("1.0"), 2));
465 assert!(general_lt::<Decimal, i32, Decimal>(dec("1.0"), 2));
466 assert!(general_is_distinct_from::<Decimal, i32, Decimal>(
467 Some(dec("1.0")),
468 Some(2)
469 ));
470 assert!(general_is_distinct_from::<Decimal, i32, Decimal>(
471 None,
472 Some(1)
473 ));
474 assert!(!general_is_distinct_from::<Decimal, i32, Decimal>(
475 Some(dec("1.0")),
476 Some(1)
477 ));
478 assert!(general_eq::<F32, i32, F64>(1.0.into(), 1));
479 assert!(!general_ne::<F32, i32, F64>(1.0.into(), 1));
480 assert!(!general_lt::<F32, i32, F64>(1.0.into(), 1));
481 assert!(general_le::<F32, i32, F64>(1.0.into(), 1));
482 assert!(!general_gt::<F32, i32, F64>(1.0.into(), 1));
483 assert!(general_ge::<F32, i32, F64>(1.0.into(), 1));
484 assert!(!general_is_distinct_from::<F32, i32, F64>(
485 Some(1.0.into()),
486 Some(1)
487 ));
488 assert!(general_eq::<i64, i32, i64>(1i64, 1));
489 assert!(!general_ne::<i64, i32, i64>(1i64, 1));
490 assert!(!general_lt::<i64, i32, i64>(1i64, 1));
491 assert!(general_le::<i64, i32, i64>(1i64, 1));
492 assert!(!general_gt::<i64, i32, i64>(1i64, 1));
493 assert!(general_ge::<i64, i32, i64>(1i64, 1));
494 assert!(!general_is_distinct_from::<i64, i32, i64>(
495 Some(1i64),
496 Some(1)
497 ));
498 }
499
500 fn dec(s: &str) -> Decimal {
501 Decimal::from_str(s).unwrap()
502 }
503
504 #[tokio::test]
505 async fn test_is_distinct_from() {
506 let (input, target) = DataChunk::from_pretty(
507 "
508 i i B
509 . . f
510 . 1 t
511 1 . t
512 2 2 f
513 3 4 t
514 ",
515 )
516 .split_column_at(2);
517 let expr = build_from_pretty("(is_distinct_from:boolean $0:int4 $1:int4)");
518 let result = expr.eval(&input).await.unwrap();
519 assert_eq!(&result, target.column_at(0));
520 }
521
522 #[tokio::test]
523 async fn test_is_not_distinct_from() {
524 let (input, target) = DataChunk::from_pretty(
525 "
526 i i B
527 . . t
528 . 1 f
529 1 . f
530 2 2 t
531 3 4 f
532 ",
533 )
534 .split_column_at(2);
535 let expr = build_from_pretty("(is_not_distinct_from:boolean $0:int4 $1:int4)");
536 let result = expr.eval(&input).await.unwrap();
537 assert_eq!(&result, target.column_at(0));
538 }
539
540 use risingwave_common::array::*;
541 use risingwave_common::row::OwnedRow;
542 use risingwave_common::types::test_utils::IntervalTestExt;
543 use risingwave_common::types::{Date, Interval};
544 use risingwave_pb::expr::expr_node::Type;
545
546 use crate::scalar::arithmetic_op::{date_interval_add, date_interval_sub};
547
548 #[tokio::test]
    /// Runs the generic binary-expression harness over a matrix of operand
    /// types and operators, checking each built expression against a plain
    /// Rust reference closure.
    async fn test_binary() {
        // int4 arithmetic and comparisons.
        test_binary_i32::<I32Array, _>(|x, y| x + y, Type::Add).await;
        test_binary_i32::<I32Array, _>(|x, y| x - y, Type::Subtract).await;
        test_binary_i32::<I32Array, _>(|x, y| x * y, Type::Multiply).await;
        test_binary_i32::<I32Array, _>(|x, y| x / y, Type::Divide).await;
        test_binary_i32::<BoolArray, _>(|x, y| x == y, Type::Equal).await;
        test_binary_i32::<BoolArray, _>(|x, y| x != y, Type::NotEqual).await;
        test_binary_i32::<BoolArray, _>(|x, y| x > y, Type::GreaterThan).await;
        test_binary_i32::<BoolArray, _>(|x, y| x >= y, Type::GreaterThanOrEqual).await;
        test_binary_i32::<BoolArray, _>(|x, y| x < y, Type::LessThan).await;
        test_binary_i32::<BoolArray, _>(|x, y| x <= y, Type::LessThanOrEqual).await;
        // greatest / least ignore NULL operands, so the reference uses `reduce`
        // rather than `arithmetic`.
        test_binary_inner::<I32Array, I32Array, I32Array, _>(
            reduce(std::cmp::max::<i32>),
            Type::Greatest,
        )
        .await;
        test_binary_inner::<I32Array, I32Array, I32Array, _>(
            reduce(std::cmp::min::<i32>),
            Type::Least,
        )
        .await;
        test_binary_inner::<BoolArray, BoolArray, BoolArray, _>(
            reduce(std::cmp::max::<bool>),
            Type::Greatest,
        )
        .await;
        test_binary_inner::<BoolArray, BoolArray, BoolArray, _>(
            reduce(std::cmp::min::<bool>),
            Type::Least,
        )
        .await;
        // decimal arithmetic and comparisons.
        test_binary_decimal::<DecimalArray, _>(|x, y| x + y, Type::Add).await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x - y, Type::Subtract).await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x * y, Type::Multiply).await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x / y, Type::Divide).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x == y, Type::Equal).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x != y, Type::NotEqual).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x > y, Type::GreaterThan).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x >= y, Type::GreaterThanOrEqual).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x < y, Type::LessThan).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x <= y, Type::LessThanOrEqual).await;
        // date +/- interval, producing a timestamp.
        test_binary_interval::<TimestampArray, _>(
            |x, y| date_interval_add(x, y).unwrap(),
            Type::Add,
        )
        .await;
        test_binary_interval::<TimestampArray, _>(
            |x, y| date_interval_sub(x, y).unwrap(),
            Type::Subtract,
        )
        .await;
    }
601
    /// Test-only helper: describes how to synthesise a value of a column type
    /// from an integer seed, plus the type name used when building expressions.
    trait TestFrom: Copy {
        /// Type name interpolated into the pretty expression string in `test_binary_inner`.
        const NAME: &'static str;
        /// Builds a value of `Self` from the seed `i`.
        fn test_from(i: usize) -> Self;
    }
606
    /// `int4` test values: the seed itself.
    impl TestFrom for i32 {
        const NAME: &'static str = "int4";

        fn test_from(i: usize) -> Self {
            // Plain `as` cast; seeds larger than `i32::MAX` would wrap.
            i as i32
        }
    }
614
    /// `decimal` test values: the seed converted via `Into<Decimal>`.
    impl TestFrom for Decimal {
        const NAME: &'static str = "decimal";

        fn test_from(i: usize) -> Self {
            i.into()
        }
    }
622
    /// `boolean` test values: `true` for even seeds, `false` for odd seeds.
    impl TestFrom for bool {
        const NAME: &'static str = "boolean";

        fn test_from(i: usize) -> Self {
            i % 2 == 0
        }
    }
630
    /// `timestamp` is implemented only so its `NAME` is available; generating
    /// input values is not supported and calling `test_from` panics.
    impl TestFrom for Timestamp {
        const NAME: &'static str = "timestamp";

        fn test_from(_: usize) -> Self {
            unimplemented!("not implemented as input yet")
        }
    }
638
    /// `interval` test values built with `Interval::from_ymd`.
    impl TestFrom for Interval {
        const NAME: &'static str = "interval";

        fn test_from(i: usize) -> Self {
            // Presumably (years, months, days) = (0, i, i) — confirm against `IntervalTestExt`.
            Interval::from_ymd(0, i as _, i as _)
        }
    }
646
    /// `date` test values: the seed passed to `Date::from_num_days_from_ce_uncheck`.
    impl TestFrom for Date {
        const NAME: &'static str = "date";

        fn test_from(i: usize) -> Self {
            // Plain `as` cast; seeds larger than `i32::MAX` would wrap.
            Date::from_num_days_from_ce_uncheck(i as i32)
        }
    }
654
655 #[expect(clippy::type_complexity)]
656 fn gen_test_data<L: TestFrom, R: TestFrom, O>(
657 count: usize,
658 f: impl Fn(Option<L>, Option<R>) -> Option<O>,
659 ) -> (Vec<Option<L>>, Vec<Option<R>>, Vec<Option<O>>) {
660 let mut lhs = Vec::<Option<L>>::new();
661 let mut rhs = Vec::<Option<R>>::new();
662 let mut target = Vec::<Option<O>>::new();
663 for i in 0..count {
664 let (l, r) = if i % 2 == 0 {
665 (Some(i), None)
666 } else if i % 3 == 0 {
667 (Some(i), Some(i + 1))
668 } else if i % 5 == 0 {
669 (Some(i + 1), Some(i))
670 } else if i % 7 == 0 {
671 (None, Some(i))
672 } else {
673 (Some(i), Some(i))
674 };
675 let l = l.map(TestFrom::test_from);
676 let r = r.map(TestFrom::test_from);
677 lhs.push(l);
678 rhs.push(r);
679 target.push(f(l, r));
680 }
681 (lhs, rhs, target)
682 }
683
684 fn arithmetic<L, R, O>(f: impl Fn(L, R) -> O) -> impl Fn(Option<L>, Option<R>) -> Option<O> {
685 move |l, r| match (l, r) {
686 (Some(l), Some(r)) => Some(f(l, r)),
687 _ => None,
688 }
689 }
690
691 fn reduce<I>(f: impl Fn(I, I) -> I) -> impl Fn(Option<I>, Option<I>) -> Option<I> {
692 move |l, r| match (l, r) {
693 (Some(l), Some(r)) => Some(f(l, r)),
694 (Some(l), None) => Some(l),
695 (None, Some(r)) => Some(r),
696 (None, None) => None,
697 }
698 }
699
700 async fn test_binary_inner<L, R, A, F>(f: F, kind: Type)
701 where
702 L: Array,
703 L: for<'a> FromIterator<&'a Option<<L as Array>::OwnedItem>>,
704 <L as Array>::OwnedItem: TestFrom,
705 R: Array,
706 R: for<'a> FromIterator<&'a Option<<R as Array>::OwnedItem>>,
707 <R as Array>::OwnedItem: TestFrom,
708 A: Array,
709 for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
710 for<'a> <A as Array>::RefItem<'a>: PartialEq,
711 <A as Array>::OwnedItem: TestFrom,
712 F: Fn(
713 Option<<L as Array>::OwnedItem>,
714 Option<<R as Array>::OwnedItem>,
715 ) -> Option<<A as Array>::OwnedItem>,
716 {
717 let (lhs, rhs, target) = gen_test_data(100, f);
718
719 let col1 = L::from_iter(&lhs).into_ref();
720 let col2 = R::from_iter(&rhs).into_ref();
721 let data_chunk = DataChunk::new(vec![col1, col2], 100);
722 let l_name = <<L as Array>::OwnedItem as TestFrom>::NAME;
723 let r_name = <<R as Array>::OwnedItem as TestFrom>::NAME;
724 let output_name = <<A as Array>::OwnedItem as TestFrom>::NAME;
725 let expr = build_from_pretty(format!(
726 "({name}:{output_name} $0:{l_name} $1:{r_name})",
727 name = kind.as_str_name(),
728 ));
729 let res = expr.eval(&data_chunk).await.unwrap();
730 let arr: &A = res.as_ref().into();
731 for (idx, item) in arr.iter().enumerate() {
732 let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
733 assert_eq!(x, item);
734 }
735
736 for i in 0..lhs.len() {
737 let row = OwnedRow::new(vec![
738 lhs[i].map(|int| int.to_scalar_value()),
739 rhs[i].map(|int| int.to_scalar_value()),
740 ]);
741 let result = expr.eval_row(&row).await.unwrap();
742 let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
743 assert_eq!(result, expected);
744 }
745 }
746
747 async fn test_binary_i32<A, F>(f: F, kind: Type)
748 where
749 A: Array,
750 for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
751 for<'a> <A as Array>::RefItem<'a>: PartialEq,
752 <A as Array>::OwnedItem: TestFrom,
753 F: Fn(i32, i32) -> <A as Array>::OwnedItem,
754 {
755 test_binary_inner::<I32Array, I32Array, _, _>(arithmetic(f), kind).await
756 }
757
758 async fn test_binary_interval<A, F>(f: F, kind: Type)
759 where
760 A: Array,
761 for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
762 for<'a> <A as Array>::RefItem<'a>: PartialEq,
763 <A as Array>::OwnedItem: TestFrom,
764 F: Fn(Date, Interval) -> <A as Array>::OwnedItem,
765 {
766 test_binary_inner::<DateArray, IntervalArray, _, _>(arithmetic(f), kind).await
767 }
768
769 async fn test_binary_decimal<A, F>(f: F, kind: Type)
770 where
771 A: Array,
772 for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
773 for<'a> <A as Array>::RefItem<'a>: PartialEq,
774 <A as Array>::OwnedItem: TestFrom,
775 F: Fn(Decimal, Decimal) -> <A as Array>::OwnedItem,
776 {
777 test_binary_inner::<DecimalArray, DecimalArray, _, _>(arithmetic(f), kind).await
778 }
779}