1use std::fmt::Debug;
16
17use constant_time_eq::constant_time_eq;
18use risingwave_common::array::{Array, BoolArray};
19use risingwave_common::bitmap::Bitmap;
20use risingwave_common::row::Row;
21use risingwave_common::types::{Scalar, ScalarRef, ScalarRefImpl};
22use risingwave_expr::{ExprError, Result, function};
23
24#[function("equal(boolean, boolean) -> boolean", batch_fn = "boolarray_eq")]
25#[function("equal(*int, *int) -> boolean")]
26#[function("equal(decimal, decimal) -> boolean")]
27#[function("equal(*float, *float) -> boolean")]
28#[function("equal(int256, int256) -> boolean")]
29#[function("equal(serial, serial) -> boolean")]
30#[function("equal(date, date) -> boolean")]
31#[function("equal(time, time) -> boolean")]
32#[function("equal(interval, interval) -> boolean")]
33#[function("equal(timestamp, timestamp) -> boolean")]
34#[function("equal(timestamptz, timestamptz) -> boolean")]
35#[function("equal(date, timestamp) -> boolean")]
36#[function("equal(timestamp, date) -> boolean")]
37#[function("equal(time, interval) -> boolean")]
38#[function("equal(interval, time) -> boolean")]
39#[function("equal(varchar, varchar) -> boolean")]
40#[function("equal(bytea, bytea) -> boolean")]
41#[function("equal(jsonb, jsonb) -> boolean")]
42#[function("equal(anyarray, anyarray) -> boolean")]
43#[function("equal(struct, struct) -> boolean")]
44pub fn general_eq<T1, T2, T3>(l: T1, r: T2) -> bool
45where
46 T1: Into<T3> + Debug,
47 T2: Into<T3> + Debug,
48 T3: Ord,
49{
50 l.into() == r.into()
51}
52
53#[function("not_equal(boolean, boolean) -> boolean", batch_fn = "boolarray_ne")]
54#[function("not_equal(*int, *int) -> boolean")]
55#[function("not_equal(decimal, decimal) -> boolean")]
56#[function("not_equal(*float, *float) -> boolean")]
57#[function("not_equal(int256, int256) -> boolean")]
58#[function("not_equal(serial, serial) -> boolean")]
59#[function("not_equal(date, date) -> boolean")]
60#[function("not_equal(time, time) -> boolean")]
61#[function("not_equal(interval, interval) -> boolean")]
62#[function("not_equal(timestamp, timestamp) -> boolean")]
63#[function("not_equal(timestamptz, timestamptz) -> boolean")]
64#[function("not_equal(date, timestamp) -> boolean")]
65#[function("not_equal(timestamp, date) -> boolean")]
66#[function("not_equal(time, interval) -> boolean")]
67#[function("not_equal(interval, time) -> boolean")]
68#[function("not_equal(varchar, varchar) -> boolean")]
69#[function("not_equal(bytea, bytea) -> boolean")]
70#[function("not_equal(jsonb, jsonb) -> boolean")]
71#[function("not_equal(anyarray, anyarray) -> boolean")]
72#[function("not_equal(struct, struct) -> boolean")]
73pub fn general_ne<T1, T2, T3>(l: T1, r: T2) -> bool
74where
75 T1: Into<T3> + Debug,
76 T2: Into<T3> + Debug,
77 T3: Ord,
78{
79 l.into() != r.into()
80}
81
82#[function(
83 "greater_than_or_equal(boolean, boolean) -> boolean",
84 batch_fn = "boolarray_ge"
85)]
86#[function("greater_than_or_equal(*int, *int) -> boolean")]
87#[function("greater_than_or_equal(decimal, decimal) -> boolean")]
88#[function("greater_than_or_equal(*float, *float) -> boolean")]
89#[function("greater_than_or_equal(serial, serial) -> boolean")]
90#[function("greater_than_or_equal(int256, int256) -> boolean")]
91#[function("greater_than_or_equal(date, date) -> boolean")]
92#[function("greater_than_or_equal(time, time) -> boolean")]
93#[function("greater_than_or_equal(interval, interval) -> boolean")]
94#[function("greater_than_or_equal(timestamp, timestamp) -> boolean")]
95#[function("greater_than_or_equal(timestamptz, timestamptz) -> boolean")]
96#[function("greater_than_or_equal(date, timestamp) -> boolean")]
97#[function("greater_than_or_equal(timestamp, date) -> boolean")]
98#[function("greater_than_or_equal(time, interval) -> boolean")]
99#[function("greater_than_or_equal(interval, time) -> boolean")]
100#[function("greater_than_or_equal(varchar, varchar) -> boolean")]
101#[function("greater_than_or_equal(bytea, bytea) -> boolean")]
102#[function("greater_than_or_equal(anyarray, anyarray) -> boolean")]
103#[function("greater_than_or_equal(struct, struct) -> boolean")]
104pub fn general_ge<T1, T2, T3>(l: T1, r: T2) -> bool
105where
106 T1: Into<T3> + Debug,
107 T2: Into<T3> + Debug,
108 T3: Ord,
109{
110 l.into() >= r.into()
111}
112
113#[function("greater_than(boolean, boolean) -> boolean", batch_fn = "boolarray_gt")]
114#[function("greater_than(*int, *int) -> boolean")]
115#[function("greater_than(decimal, decimal) -> boolean")]
116#[function("greater_than(*float, *float) -> boolean")]
117#[function("greater_than(serial, serial) -> boolean")]
118#[function("greater_than(int256, int256) -> boolean")]
119#[function("greater_than(date, date) -> boolean")]
120#[function("greater_than(time, time) -> boolean")]
121#[function("greater_than(interval, interval) -> boolean")]
122#[function("greater_than(timestamp, timestamp) -> boolean")]
123#[function("greater_than(timestamptz, timestamptz) -> boolean")]
124#[function("greater_than(date, timestamp) -> boolean")]
125#[function("greater_than(timestamp, date) -> boolean")]
126#[function("greater_than(time, interval) -> boolean")]
127#[function("greater_than(interval, time) -> boolean")]
128#[function("greater_than(varchar, varchar) -> boolean")]
129#[function("greater_than(bytea, bytea) -> boolean")]
130#[function("greater_than(anyarray, anyarray) -> boolean")]
131#[function("greater_than(struct, struct) -> boolean")]
132pub fn general_gt<T1, T2, T3>(l: T1, r: T2) -> bool
133where
134 T1: Into<T3> + Debug,
135 T2: Into<T3> + Debug,
136 T3: Ord,
137{
138 l.into() > r.into()
139}
140
141#[function(
142 "less_than_or_equal(boolean, boolean) -> boolean",
143 batch_fn = "boolarray_le"
144)]
145#[function("less_than_or_equal(*int, *int) -> boolean")]
146#[function("less_than_or_equal(decimal, decimal) -> boolean")]
147#[function("less_than_or_equal(*float, *float) -> boolean")]
148#[function("less_than_or_equal(serial, serial) -> boolean")]
149#[function("less_than_or_equal(int256, int256) -> boolean")]
150#[function("less_than_or_equal(date, date) -> boolean")]
151#[function("less_than_or_equal(time, time) -> boolean")]
152#[function("less_than_or_equal(interval, interval) -> boolean")]
153#[function("less_than_or_equal(timestamp, timestamp) -> boolean")]
154#[function("less_than_or_equal(timestamptz, timestamptz) -> boolean")]
155#[function("less_than_or_equal(date, timestamp) -> boolean")]
156#[function("less_than_or_equal(timestamp, date) -> boolean")]
157#[function("less_than_or_equal(time, interval) -> boolean")]
158#[function("less_than_or_equal(interval, time) -> boolean")]
159#[function("less_than_or_equal(varchar, varchar) -> boolean")]
160#[function("less_than_or_equal(bytea, bytea) -> boolean")]
161#[function("less_than_or_equal(anyarray, anyarray) -> boolean")]
162#[function("less_than_or_equal(struct, struct) -> boolean")]
163pub fn general_le<T1, T2, T3>(l: T1, r: T2) -> bool
164where
165 T1: Into<T3> + Debug,
166 T2: Into<T3> + Debug,
167 T3: Ord,
168{
169 l.into() <= r.into()
170}
171
172#[function("less_than(boolean, boolean) -> boolean", batch_fn = "boolarray_lt")]
173#[function("less_than(*int, *int) -> boolean")]
174#[function("less_than(decimal, decimal) -> boolean")]
175#[function("less_than(*float, *float) -> boolean")]
176#[function("less_than(serial, serial) -> boolean")]
177#[function("less_than(int256, int256) -> boolean")]
178#[function("less_than(date, date) -> boolean")]
179#[function("less_than(time, time) -> boolean")]
180#[function("less_than(interval, interval) -> boolean")]
181#[function("less_than(timestamp, timestamp) -> boolean")]
182#[function("less_than(timestamptz, timestamptz) -> boolean")]
183#[function("less_than(date, timestamp) -> boolean")]
184#[function("less_than(timestamp, date) -> boolean")]
185#[function("less_than(time, interval) -> boolean")]
186#[function("less_than(interval, time) -> boolean")]
187#[function("less_than(varchar, varchar) -> boolean")]
188#[function("less_than(bytea, bytea) -> boolean")]
189#[function("less_than(anyarray, anyarray) -> boolean")]
190#[function("less_than(struct, struct) -> boolean")]
191pub fn general_lt<T1, T2, T3>(l: T1, r: T2) -> bool
192where
193 T1: Into<T3> + Debug,
194 T2: Into<T3> + Debug,
195 T3: Ord,
196{
197 l.into() < r.into()
198}
199
200#[function(
201 "is_distinct_from(boolean, boolean) -> boolean",
202 batch_fn = "boolarray_is_distinct_from"
203)]
204#[function("is_distinct_from(*int, *int) -> boolean")]
205#[function("is_distinct_from(decimal, decimal) -> boolean")]
206#[function("is_distinct_from(*float, *float) -> boolean")]
207#[function("is_distinct_from(serial, serial) -> boolean")]
208#[function("is_distinct_from(int256, int256) -> boolean")]
209#[function("is_distinct_from(date, date) -> boolean")]
210#[function("is_distinct_from(time, time) -> boolean")]
211#[function("is_distinct_from(interval, interval) -> boolean")]
212#[function("is_distinct_from(timestamp, timestamp) -> boolean")]
213#[function("is_distinct_from(timestamptz, timestamptz) -> boolean")]
214#[function("is_distinct_from(jsonb, jsonb) -> boolean")]
215#[function("is_distinct_from(date, timestamp) -> boolean")]
216#[function("is_distinct_from(timestamp, date) -> boolean")]
217#[function("is_distinct_from(time, interval) -> boolean")]
218#[function("is_distinct_from(interval, time) -> boolean")]
219#[function("is_distinct_from(varchar, varchar) -> boolean")]
220#[function("is_distinct_from(bytea, bytea) -> boolean")]
221#[function("is_distinct_from(anyarray, anyarray) -> boolean")]
222#[function("is_distinct_from(struct, struct) -> boolean")]
223pub fn general_is_distinct_from<T1, T2, T3>(l: Option<T1>, r: Option<T2>) -> bool
224where
225 T1: Into<T3> + Debug,
226 T2: Into<T3> + Debug,
227 T3: Ord,
228{
229 l.map(Into::into) != r.map(Into::into)
230}
231
232#[function(
233 "is_not_distinct_from(boolean, boolean) -> boolean",
234 batch_fn = "boolarray_is_not_distinct_from"
235)]
236#[function("is_not_distinct_from(*int, *int) -> boolean")]
237#[function("is_not_distinct_from(decimal, decimal) -> boolean")]
238#[function("is_not_distinct_from(*float, *float) -> boolean")]
239#[function("is_not_distinct_from(serial, serial) -> boolean")]
240#[function("is_not_distinct_from(int256, int256) -> boolean")]
241#[function("is_not_distinct_from(date, date) -> boolean")]
242#[function("is_not_distinct_from(time, time) -> boolean")]
243#[function("is_not_distinct_from(interval, interval) -> boolean")]
244#[function("is_not_distinct_from(timestamp, timestamp) -> boolean")]
245#[function("is_not_distinct_from(timestamptz, timestamptz) -> boolean")]
246#[function("is_not_distinct_from(jsonb, jsonb) -> boolean")]
247#[function("is_not_distinct_from(date, timestamp) -> boolean")]
248#[function("is_not_distinct_from(timestamp, date) -> boolean")]
249#[function("is_not_distinct_from(time, interval) -> boolean")]
250#[function("is_not_distinct_from(interval, time) -> boolean")]
251#[function("is_not_distinct_from(varchar, varchar) -> boolean")]
252#[function("is_not_distinct_from(bytea, bytea) -> boolean")]
253#[function("is_not_distinct_from(anyarray, anyarray) -> boolean")]
254#[function("is_not_distinct_from(struct, struct) -> boolean")]
255pub fn general_is_not_distinct_from<T1, T2, T3>(l: Option<T1>, r: Option<T2>) -> bool
256where
257 T1: Into<T3> + Debug,
258 T2: Into<T3> + Debug,
259 T3: Ord,
260{
261 l.map(Into::into) == r.map(Into::into)
262}
263
264#[function("is_true(boolean) -> boolean", batch_fn = "boolarray_is_true")]
265pub fn is_true(v: Option<bool>) -> bool {
266 v == Some(true)
267}
268
269#[function("is_not_true(boolean) -> boolean", batch_fn = "boolarray_is_not_true")]
270pub fn is_not_true(v: Option<bool>) -> bool {
271 v != Some(true)
272}
273
274#[function("is_false(boolean) -> boolean", batch_fn = "boolarray_is_false")]
275pub fn is_false(v: Option<bool>) -> bool {
276 v == Some(false)
277}
278
279#[function(
280 "is_not_false(boolean) -> boolean",
281 batch_fn = "boolarray_is_not_false"
282)]
283pub fn is_not_false(v: Option<bool>) -> bool {
284 v != Some(false)
285}
286
287#[function("is_null(*) -> boolean", batch_fn = "batch_is_null")]
288fn is_null<T>(v: Option<T>) -> bool {
289 v.is_none()
290}
291
292#[function("is_not_null(*) -> boolean", batch_fn = "batch_is_not_null")]
293fn is_not_null<T>(v: Option<T>) -> bool {
294 v.is_some()
295}
296
297#[function("greatest(...) -> boolean")]
298#[function("greatest(...) -> *int")]
299#[function("greatest(...) -> decimal")]
300#[function("greatest(...) -> *float")]
301#[function("greatest(...) -> serial")]
302#[function("greatest(...) -> int256")]
303#[function("greatest(...) -> date")]
304#[function("greatest(...) -> time")]
305#[function("greatest(...) -> interval")]
306#[function("greatest(...) -> timestamp")]
307#[function("greatest(...) -> timestamptz")]
308#[function("greatest(...) -> varchar")]
309#[function("greatest(...) -> bytea")]
310pub fn general_variadic_greatest<T>(row: impl Row) -> Option<T>
311where
312 T: Scalar,
313 for<'a> <T as Scalar>::ScalarRefType<'a>: TryFrom<ScalarRefImpl<'a>> + Ord + Debug,
314{
315 row.iter()
316 .flatten()
317 .map(
318 |scalar| match <<T as Scalar>::ScalarRefType<'_>>::try_from(scalar) {
319 Ok(v) => v,
320 Err(_) => unreachable!("all input type should have been aligned in the frontend"),
321 },
322 )
323 .max()
324 .map(|v| v.to_owned_scalar())
325}
326
327#[function("least(...) -> boolean")]
328#[function("least(...) -> *int")]
329#[function("least(...) -> decimal")]
330#[function("least(...) -> *float")]
331#[function("least(...) -> serial")]
332#[function("least(...) -> int256")]
333#[function("least(...) -> date")]
334#[function("least(...) -> time")]
335#[function("least(...) -> interval")]
336#[function("least(...) -> timestamp")]
337#[function("least(...) -> timestamptz")]
338#[function("least(...) -> varchar")]
339#[function("least(...) -> bytea")]
340pub fn general_variadic_least<T>(row: impl Row) -> Option<T>
341where
342 T: Scalar,
343 for<'a> <T as Scalar>::ScalarRefType<'a>: TryFrom<ScalarRefImpl<'a>> + Ord + Debug,
344{
345 row.iter()
346 .flatten()
347 .map(
348 |scalar| match <<T as Scalar>::ScalarRefType<'_>>::try_from(scalar) {
349 Ok(v) => v,
350 Err(_) => unreachable!("all input type should have been aligned in the frontend"),
351 },
352 )
353 .min()
354 .map(|v| v.to_owned_scalar())
355}
356
357fn boolarray_eq(l: &BoolArray, r: &BoolArray) -> BoolArray {
360 let data = !(l.data() ^ r.data());
361 let bitmap = l.null_bitmap() & r.null_bitmap();
362 BoolArray::new(data, bitmap)
363}
364
365fn boolarray_ne(l: &BoolArray, r: &BoolArray) -> BoolArray {
366 let data = l.data() ^ r.data();
367 let bitmap = l.null_bitmap() & r.null_bitmap();
368 BoolArray::new(data, bitmap)
369}
370
371fn boolarray_gt(l: &BoolArray, r: &BoolArray) -> BoolArray {
372 let data = l.data() & !r.data();
373 let bitmap = l.null_bitmap() & r.null_bitmap();
374 BoolArray::new(data, bitmap)
375}
376
377fn boolarray_lt(l: &BoolArray, r: &BoolArray) -> BoolArray {
378 let data = !l.data() & r.data();
379 let bitmap = l.null_bitmap() & r.null_bitmap();
380 BoolArray::new(data, bitmap)
381}
382
383fn boolarray_ge(l: &BoolArray, r: &BoolArray) -> BoolArray {
384 let data = l.data() | !r.data();
385 let bitmap = l.null_bitmap() & r.null_bitmap();
386 BoolArray::new(data, bitmap)
387}
388
389fn boolarray_le(l: &BoolArray, r: &BoolArray) -> BoolArray {
390 let data = !l.data() | r.data();
391 let bitmap = l.null_bitmap() & r.null_bitmap();
392 BoolArray::new(data, bitmap)
393}
394
395fn boolarray_is_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray {
396 let data = ((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap()))
397 | (l.null_bitmap() ^ r.null_bitmap());
398 BoolArray::new(data, Bitmap::ones(l.len()))
399}
400
401fn boolarray_is_not_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray {
402 let data = !(((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap()))
403 | (l.null_bitmap() ^ r.null_bitmap()));
404 BoolArray::new(data, Bitmap::ones(l.len()))
405}
406
407fn boolarray_is_true(a: &BoolArray) -> BoolArray {
408 BoolArray::new(a.to_bitmap(), Bitmap::ones(a.len()))
409}
410
411fn boolarray_is_not_true(a: &BoolArray) -> BoolArray {
412 BoolArray::new(!a.to_bitmap(), Bitmap::ones(a.len()))
413}
414
415fn boolarray_is_false(a: &BoolArray) -> BoolArray {
416 BoolArray::new(!a.data() & a.null_bitmap(), Bitmap::ones(a.len()))
417}
418
419fn boolarray_is_not_false(a: &BoolArray) -> BoolArray {
420 BoolArray::new(a.data() | !a.null_bitmap(), Bitmap::ones(a.len()))
421}
422
423fn batch_is_null(a: &impl Array) -> BoolArray {
424 BoolArray::new(!a.null_bitmap(), Bitmap::ones(a.len()))
425}
426
427fn batch_is_not_null(a: &impl Array) -> BoolArray {
428 BoolArray::new(a.null_bitmap().clone(), Bitmap::ones(a.len()))
429}
430
431#[function("secure_compare(varchar, varchar) -> boolean")]
432pub fn secure_compare(left: &str, right: &str) -> bool {
433 constant_time_eq(left.as_bytes(), right.as_bytes())
434}
435
436#[function("check_not_null(any, varchar, varchar) -> any")]
437fn check_not_null<'a>(
438 v: Option<ScalarRefImpl<'a>>,
439 col_name: &str,
440 relation_name: &str,
441) -> Result<Option<ScalarRefImpl<'a>>> {
442 if v.is_none() {
443 return Err(ExprError::NotNullViolation {
444 col_name: col_name.into(),
445 table_name: relation_name.into(),
446 });
447 }
448 Ok(v)
449}
450
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use risingwave_common::array::*;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_common::types::{Date, Decimal, F32, F64, Interval, Timestamp};
    use risingwave_expr::expr::build_from_pretty;
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::scalar::arithmetic_op::{date_interval_add, date_interval_sub};

    /// Scalar-level checks of the generic comparison functions, including
    /// mixed-type comparisons that widen both operands into a common `T3`.
    #[test]
    fn test_comparison() {
        assert!(general_eq::<Decimal, i32, Decimal>(dec("1.0"), 1));
        assert!(!general_ne::<Decimal, i32, Decimal>(dec("1.0"), 1));
        assert!(!general_gt::<Decimal, i32, Decimal>(dec("1.0"), 2));
        assert!(general_le::<Decimal, i32, Decimal>(dec("1.0"), 2));
        assert!(!general_ge::<Decimal, i32, Decimal>(dec("1.0"), 2));
        assert!(general_lt::<Decimal, i32, Decimal>(dec("1.0"), 2));
        assert!(general_is_distinct_from::<Decimal, i32, Decimal>(
            Some(dec("1.0")),
            Some(2)
        ));
        assert!(general_is_distinct_from::<Decimal, i32, Decimal>(
            None,
            Some(1)
        ));
        assert!(!general_is_distinct_from::<Decimal, i32, Decimal>(
            Some(dec("1.0")),
            Some(1)
        ));
        assert!(general_eq::<F32, i32, F64>(1.0.into(), 1));
        assert!(!general_ne::<F32, i32, F64>(1.0.into(), 1));
        assert!(!general_lt::<F32, i32, F64>(1.0.into(), 1));
        assert!(general_le::<F32, i32, F64>(1.0.into(), 1));
        assert!(!general_gt::<F32, i32, F64>(1.0.into(), 1));
        assert!(general_ge::<F32, i32, F64>(1.0.into(), 1));
        assert!(!general_is_distinct_from::<F32, i32, F64>(
            Some(1.0.into()),
            Some(1)
        ));
        assert!(general_eq::<i64, i32, i64>(1i64, 1));
        assert!(!general_ne::<i64, i32, i64>(1i64, 1));
        assert!(!general_lt::<i64, i32, i64>(1i64, 1));
        assert!(general_le::<i64, i32, i64>(1i64, 1));
        assert!(!general_gt::<i64, i32, i64>(1i64, 1));
        assert!(general_ge::<i64, i32, i64>(1i64, 1));
        assert!(!general_is_distinct_from::<i64, i32, i64>(
            Some(1i64),
            Some(1)
        ));
    }

    /// Parses a decimal literal; panics on malformed input.
    fn dec(s: &str) -> Decimal {
        Decimal::from_str(s).unwrap()
    }

    /// `is_distinct_from` never yields NULL: two NULLs are not distinct, while
    /// NULL vs. non-NULL is.
    #[tokio::test]
    async fn test_is_distinct_from() {
        let (input, target) = DataChunk::from_pretty(
            "
            i i B
            . . f
            . 1 t
            1 . t
            2 2 f
            3 4 t
            ",
        )
        .split_column_at(2);
        let expr = build_from_pretty("(is_distinct_from:boolean $0:int4 $1:int4)");
        let result = expr.eval(&input).await.unwrap();
        assert_eq!(&result, target.column_at(0));
    }

    /// `is_not_distinct_from` is the exact complement of `is_distinct_from`.
    #[tokio::test]
    async fn test_is_not_distinct_from() {
        let (input, target) = DataChunk::from_pretty(
            "
            i i B
            . . t
            . 1 f
            1 . f
            2 2 t
            3 4 f
            ",
        )
        .split_column_at(2);
        let expr = build_from_pretty("(is_not_distinct_from:boolean $0:int4 $1:int4)");
        let result = expr.eval(&input).await.unwrap();
        assert_eq!(&result, target.column_at(0));
    }

    /// Cross-checks binary expressions (arithmetic, comparison, greatest/least)
    /// built from protobuf `Type`s against plain Rust reference closures.
    #[tokio::test]
    async fn test_binary() {
        test_binary_i32::<I32Array, _>(|x, y| x + y, Type::Add).await;
        test_binary_i32::<I32Array, _>(|x, y| x - y, Type::Subtract).await;
        test_binary_i32::<I32Array, _>(|x, y| x * y, Type::Multiply).await;
        test_binary_i32::<I32Array, _>(|x, y| x / y, Type::Divide).await;
        test_binary_i32::<BoolArray, _>(|x, y| x == y, Type::Equal).await;
        test_binary_i32::<BoolArray, _>(|x, y| x != y, Type::NotEqual).await;
        test_binary_i32::<BoolArray, _>(|x, y| x > y, Type::GreaterThan).await;
        test_binary_i32::<BoolArray, _>(|x, y| x >= y, Type::GreaterThanOrEqual).await;
        test_binary_i32::<BoolArray, _>(|x, y| x < y, Type::LessThan).await;
        test_binary_i32::<BoolArray, _>(|x, y| x <= y, Type::LessThanOrEqual).await;
        test_binary_inner::<I32Array, I32Array, I32Array, _>(
            reduce(std::cmp::max::<i32>),
            Type::Greatest,
        )
        .await;
        test_binary_inner::<I32Array, I32Array, I32Array, _>(
            reduce(std::cmp::min::<i32>),
            Type::Least,
        )
        .await;
        test_binary_inner::<BoolArray, BoolArray, BoolArray, _>(
            reduce(std::cmp::max::<bool>),
            Type::Greatest,
        )
        .await;
        test_binary_inner::<BoolArray, BoolArray, BoolArray, _>(
            reduce(std::cmp::min::<bool>),
            Type::Least,
        )
        .await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x + y, Type::Add).await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x - y, Type::Subtract).await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x * y, Type::Multiply).await;
        test_binary_decimal::<DecimalArray, _>(|x, y| x / y, Type::Divide).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x == y, Type::Equal).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x != y, Type::NotEqual).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x > y, Type::GreaterThan).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x >= y, Type::GreaterThanOrEqual).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x < y, Type::LessThan).await;
        test_binary_decimal::<BoolArray, _>(|x, y| x <= y, Type::LessThanOrEqual).await;
        test_binary_interval::<TimestampArray, _>(
            |x, y| date_interval_add(x, y).unwrap(),
            Type::Add,
        )
        .await;
        test_binary_interval::<TimestampArray, _>(
            |x, y| date_interval_sub(x, y).unwrap(),
            Type::Subtract,
        )
        .await;
    }

    /// Builds deterministic test values from a row index and names the SQL
    /// type used in pretty-printed expressions.
    trait TestFrom: Copy {
        const NAME: &'static str;
        fn test_from(i: usize) -> Self;
    }

    impl TestFrom for i32 {
        const NAME: &'static str = "int4";

        fn test_from(i: usize) -> Self {
            i as i32
        }
    }

    impl TestFrom for Decimal {
        const NAME: &'static str = "decimal";

        fn test_from(i: usize) -> Self {
            i.into()
        }
    }

    impl TestFrom for bool {
        const NAME: &'static str = "boolean";

        fn test_from(i: usize) -> Self {
            i % 2 == 0
        }
    }

    // Only used as an output type (for `NAME`); never generated as input.
    impl TestFrom for Timestamp {
        const NAME: &'static str = "timestamp";

        fn test_from(_: usize) -> Self {
            unimplemented!("not implemented as input yet")
        }
    }

    impl TestFrom for Interval {
        const NAME: &'static str = "interval";

        fn test_from(i: usize) -> Self {
            Interval::from_ymd(0, i as _, i as _)
        }
    }

    impl TestFrom for Date {
        const NAME: &'static str = "date";

        fn test_from(i: usize) -> Self {
            Date::from_num_days_from_ce_uncheck(i as i32)
        }
    }

    /// Generates `count` (lhs, rhs, expected) triples.
    ///
    /// Row `i` is `(i, NULL)` when even; otherwise `(i, i + 1)` for multiples of
    /// 3, `(i + 1, i)` for multiples of 5, `(NULL, i)` for multiples of 7, and
    /// `(i, i)` for the rest. The expected value is `f(lhs, rhs)`.
    #[expect(clippy::type_complexity)]
    fn gen_test_data<L: TestFrom, R: TestFrom, O>(
        count: usize,
        f: impl Fn(Option<L>, Option<R>) -> Option<O>,
    ) -> (Vec<Option<L>>, Vec<Option<R>>, Vec<Option<O>>) {
        let mut lhs = Vec::<Option<L>>::new();
        let mut rhs = Vec::<Option<R>>::new();
        let mut target = Vec::<Option<O>>::new();
        for i in 0..count {
            let (l, r) = if i % 2 == 0 {
                (Some(i), None)
            } else if i % 3 == 0 {
                (Some(i), Some(i + 1))
            } else if i % 5 == 0 {
                (Some(i + 1), Some(i))
            } else if i % 7 == 0 {
                (None, Some(i))
            } else {
                (Some(i), Some(i))
            };
            let l = l.map(TestFrom::test_from);
            let r = r.map(TestFrom::test_from);
            lhs.push(l);
            rhs.push(r);
            target.push(f(l, r));
        }
        (lhs, rhs, target)
    }

    /// Lifts a binary function to NULL-propagating semantics: NULL in, NULL out.
    fn arithmetic<L, R, O>(f: impl Fn(L, R) -> O) -> impl Fn(Option<L>, Option<R>) -> Option<O> {
        move |l, r| match (l, r) {
            (Some(l), Some(r)) => Some(f(l, r)),
            _ => None,
        }
    }

    /// Lifts a binary reducer to NULL-skipping semantics (as used by
    /// `greatest`/`least`): NULL only when both inputs are NULL.
    fn reduce<I>(f: impl Fn(I, I) -> I) -> impl Fn(Option<I>, Option<I>) -> Option<I> {
        move |l, r| match (l, r) {
            (Some(l), Some(r)) => Some(f(l, r)),
            (Some(l), None) => Some(l),
            (None, Some(r)) => Some(r),
            (None, None) => None,
        }
    }

    /// Builds the expression `kind($0, $1)`, evaluates it both vectorized
    /// (`eval`) and row-by-row (`eval_row`), and compares both against `f`.
    async fn test_binary_inner<L, R, A, F>(f: F, kind: Type)
    where
        L: Array,
        L: for<'a> FromIterator<&'a Option<<L as Array>::OwnedItem>>,
        <L as Array>::OwnedItem: TestFrom,
        R: Array,
        R: for<'a> FromIterator<&'a Option<<R as Array>::OwnedItem>>,
        <R as Array>::OwnedItem: TestFrom,
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        <A as Array>::OwnedItem: TestFrom,
        F: Fn(
            Option<<L as Array>::OwnedItem>,
            Option<<R as Array>::OwnedItem>,
        ) -> Option<<A as Array>::OwnedItem>,
    {
        let (lhs, rhs, target) = gen_test_data(100, f);

        let col1 = L::from_iter(&lhs).into_ref();
        let col2 = R::from_iter(&rhs).into_ref();
        let data_chunk = DataChunk::new(vec![col1, col2], 100);
        let l_name = <<L as Array>::OwnedItem as TestFrom>::NAME;
        let r_name = <<R as Array>::OwnedItem as TestFrom>::NAME;
        let output_name = <<A as Array>::OwnedItem as TestFrom>::NAME;
        let expr = build_from_pretty(format!(
            "({name}:{output_name} $0:{l_name} $1:{r_name})",
            name = kind.as_str_name(),
        ));
        let res = expr.eval(&data_chunk).await.unwrap();
        let arr: &A = res.as_ref().into();
        // Without this, a short result array would silently skip the
        // per-element checks below.
        assert_eq!(arr.len(), target.len());
        for (idx, item) in arr.iter().enumerate() {
            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
            assert_eq!(x, item);
        }

        for i in 0..lhs.len() {
            let row = OwnedRow::new(vec![
                lhs[i].map(|int| int.to_scalar_value()),
                rhs[i].map(|int| int.to_scalar_value()),
            ]);
            let result = expr.eval_row(&row).await.unwrap();
            let expected = target[i].clone().map(|x| x.to_scalar_value());
            assert_eq!(result, expected);
        }
    }

    /// `test_binary_inner` over two `int4` inputs with NULL-propagating `f`.
    async fn test_binary_i32<A, F>(f: F, kind: Type)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        <A as Array>::OwnedItem: TestFrom,
        F: Fn(i32, i32) -> <A as Array>::OwnedItem,
    {
        test_binary_inner::<I32Array, I32Array, _, _>(arithmetic(f), kind).await
    }

    /// `test_binary_inner` over `date` and `interval` inputs with
    /// NULL-propagating `f`.
    async fn test_binary_interval<A, F>(f: F, kind: Type)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        <A as Array>::OwnedItem: TestFrom,
        F: Fn(Date, Interval) -> <A as Array>::OwnedItem,
    {
        test_binary_inner::<DateArray, IntervalArray, _, _>(arithmetic(f), kind).await
    }

    /// `test_binary_inner` over two `decimal` inputs with NULL-propagating `f`.
    async fn test_binary_decimal<A, F>(f: F, kind: Type)
    where
        A: Array,
        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
        for<'a> <A as Array>::RefItem<'a>: PartialEq,
        <A as Array>::OwnedItem: TestFrom,
        F: Fn(Decimal, Decimal) -> <A as Array>::OwnedItem,
    {
        test_binary_inner::<DecimalArray, DecimalArray, _, _>(arithmetic(f), kind).await
    }
}