risingwave_expr_impl/scalar/
cmp.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16
17use constant_time_eq::constant_time_eq;
18use risingwave_common::array::{Array, BoolArray};
19use risingwave_common::bitmap::Bitmap;
20use risingwave_common::row::Row;
21use risingwave_common::types::{Scalar, ScalarRef, ScalarRefImpl};
22use risingwave_expr::{ExprError, Result, function};
23
24#[function("equal(boolean, boolean) -> boolean", batch_fn = "boolarray_eq")]
25#[function("equal(*int, *int) -> boolean")]
26#[function("equal(decimal, decimal) -> boolean")]
27#[function("equal(*float, *float) -> boolean")]
28#[function("equal(int256, int256) -> boolean")]
29#[function("equal(serial, serial) -> boolean")]
30#[function("equal(date, date) -> boolean")]
31#[function("equal(time, time) -> boolean")]
32#[function("equal(interval, interval) -> boolean")]
33#[function("equal(timestamp, timestamp) -> boolean")]
34#[function("equal(timestamptz, timestamptz) -> boolean")]
35#[function("equal(date, timestamp) -> boolean")]
36#[function("equal(timestamp, date) -> boolean")]
37#[function("equal(time, interval) -> boolean")]
38#[function("equal(interval, time) -> boolean")]
39#[function("equal(varchar, varchar) -> boolean")]
40#[function("equal(bytea, bytea) -> boolean")]
41#[function("equal(jsonb, jsonb) -> boolean")]
42#[function("equal(anyarray, anyarray) -> boolean")]
43#[function("equal(struct, struct) -> boolean")]
44pub fn general_eq<T1, T2, T3>(l: T1, r: T2) -> bool
45where
46    T1: Into<T3> + Debug,
47    T2: Into<T3> + Debug,
48    T3: Ord,
49{
50    l.into() == r.into()
51}
52
53#[function("not_equal(boolean, boolean) -> boolean", batch_fn = "boolarray_ne")]
54#[function("not_equal(*int, *int) -> boolean")]
55#[function("not_equal(decimal, decimal) -> boolean")]
56#[function("not_equal(*float, *float) -> boolean")]
57#[function("not_equal(int256, int256) -> boolean")]
58#[function("not_equal(serial, serial) -> boolean")]
59#[function("not_equal(date, date) -> boolean")]
60#[function("not_equal(time, time) -> boolean")]
61#[function("not_equal(interval, interval) -> boolean")]
62#[function("not_equal(timestamp, timestamp) -> boolean")]
63#[function("not_equal(timestamptz, timestamptz) -> boolean")]
64#[function("not_equal(date, timestamp) -> boolean")]
65#[function("not_equal(timestamp, date) -> boolean")]
66#[function("not_equal(time, interval) -> boolean")]
67#[function("not_equal(interval, time) -> boolean")]
68#[function("not_equal(varchar, varchar) -> boolean")]
69#[function("not_equal(bytea, bytea) -> boolean")]
70#[function("not_equal(jsonb, jsonb) -> boolean")]
71#[function("not_equal(anyarray, anyarray) -> boolean")]
72#[function("not_equal(struct, struct) -> boolean")]
73pub fn general_ne<T1, T2, T3>(l: T1, r: T2) -> bool
74where
75    T1: Into<T3> + Debug,
76    T2: Into<T3> + Debug,
77    T3: Ord,
78{
79    l.into() != r.into()
80}
81
82#[function(
83    "greater_than_or_equal(boolean, boolean) -> boolean",
84    batch_fn = "boolarray_ge"
85)]
86#[function("greater_than_or_equal(*int, *int) -> boolean")]
87#[function("greater_than_or_equal(decimal, decimal) -> boolean")]
88#[function("greater_than_or_equal(*float, *float) -> boolean")]
89#[function("greater_than_or_equal(serial, serial) -> boolean")]
90#[function("greater_than_or_equal(int256, int256) -> boolean")]
91#[function("greater_than_or_equal(date, date) -> boolean")]
92#[function("greater_than_or_equal(time, time) -> boolean")]
93#[function("greater_than_or_equal(interval, interval) -> boolean")]
94#[function("greater_than_or_equal(timestamp, timestamp) -> boolean")]
95#[function("greater_than_or_equal(timestamptz, timestamptz) -> boolean")]
96#[function("greater_than_or_equal(date, timestamp) -> boolean")]
97#[function("greater_than_or_equal(timestamp, date) -> boolean")]
98#[function("greater_than_or_equal(time, interval) -> boolean")]
99#[function("greater_than_or_equal(interval, time) -> boolean")]
100#[function("greater_than_or_equal(varchar, varchar) -> boolean")]
101#[function("greater_than_or_equal(bytea, bytea) -> boolean")]
102#[function("greater_than_or_equal(anyarray, anyarray) -> boolean")]
103#[function("greater_than_or_equal(struct, struct) -> boolean")]
104pub fn general_ge<T1, T2, T3>(l: T1, r: T2) -> bool
105where
106    T1: Into<T3> + Debug,
107    T2: Into<T3> + Debug,
108    T3: Ord,
109{
110    l.into() >= r.into()
111}
112
113#[function("greater_than(boolean, boolean) -> boolean", batch_fn = "boolarray_gt")]
114#[function("greater_than(*int, *int) -> boolean")]
115#[function("greater_than(decimal, decimal) -> boolean")]
116#[function("greater_than(*float, *float) -> boolean")]
117#[function("greater_than(serial, serial) -> boolean")]
118#[function("greater_than(int256, int256) -> boolean")]
119#[function("greater_than(date, date) -> boolean")]
120#[function("greater_than(time, time) -> boolean")]
121#[function("greater_than(interval, interval) -> boolean")]
122#[function("greater_than(timestamp, timestamp) -> boolean")]
123#[function("greater_than(timestamptz, timestamptz) -> boolean")]
124#[function("greater_than(date, timestamp) -> boolean")]
125#[function("greater_than(timestamp, date) -> boolean")]
126#[function("greater_than(time, interval) -> boolean")]
127#[function("greater_than(interval, time) -> boolean")]
128#[function("greater_than(varchar, varchar) -> boolean")]
129#[function("greater_than(bytea, bytea) -> boolean")]
130#[function("greater_than(anyarray, anyarray) -> boolean")]
131#[function("greater_than(struct, struct) -> boolean")]
132pub fn general_gt<T1, T2, T3>(l: T1, r: T2) -> bool
133where
134    T1: Into<T3> + Debug,
135    T2: Into<T3> + Debug,
136    T3: Ord,
137{
138    l.into() > r.into()
139}
140
141#[function(
142    "less_than_or_equal(boolean, boolean) -> boolean",
143    batch_fn = "boolarray_le"
144)]
145#[function("less_than_or_equal(*int, *int) -> boolean")]
146#[function("less_than_or_equal(decimal, decimal) -> boolean")]
147#[function("less_than_or_equal(*float, *float) -> boolean")]
148#[function("less_than_or_equal(serial, serial) -> boolean")]
149#[function("less_than_or_equal(int256, int256) -> boolean")]
150#[function("less_than_or_equal(date, date) -> boolean")]
151#[function("less_than_or_equal(time, time) -> boolean")]
152#[function("less_than_or_equal(interval, interval) -> boolean")]
153#[function("less_than_or_equal(timestamp, timestamp) -> boolean")]
154#[function("less_than_or_equal(timestamptz, timestamptz) -> boolean")]
155#[function("less_than_or_equal(date, timestamp) -> boolean")]
156#[function("less_than_or_equal(timestamp, date) -> boolean")]
157#[function("less_than_or_equal(time, interval) -> boolean")]
158#[function("less_than_or_equal(interval, time) -> boolean")]
159#[function("less_than_or_equal(varchar, varchar) -> boolean")]
160#[function("less_than_or_equal(bytea, bytea) -> boolean")]
161#[function("less_than_or_equal(anyarray, anyarray) -> boolean")]
162#[function("less_than_or_equal(struct, struct) -> boolean")]
163pub fn general_le<T1, T2, T3>(l: T1, r: T2) -> bool
164where
165    T1: Into<T3> + Debug,
166    T2: Into<T3> + Debug,
167    T3: Ord,
168{
169    l.into() <= r.into()
170}
171
172#[function("less_than(boolean, boolean) -> boolean", batch_fn = "boolarray_lt")]
173#[function("less_than(*int, *int) -> boolean")]
174#[function("less_than(decimal, decimal) -> boolean")]
175#[function("less_than(*float, *float) -> boolean")]
176#[function("less_than(serial, serial) -> boolean")]
177#[function("less_than(int256, int256) -> boolean")]
178#[function("less_than(date, date) -> boolean")]
179#[function("less_than(time, time) -> boolean")]
180#[function("less_than(interval, interval) -> boolean")]
181#[function("less_than(timestamp, timestamp) -> boolean")]
182#[function("less_than(timestamptz, timestamptz) -> boolean")]
183#[function("less_than(date, timestamp) -> boolean")]
184#[function("less_than(timestamp, date) -> boolean")]
185#[function("less_than(time, interval) -> boolean")]
186#[function("less_than(interval, time) -> boolean")]
187#[function("less_than(varchar, varchar) -> boolean")]
188#[function("less_than(bytea, bytea) -> boolean")]
189#[function("less_than(anyarray, anyarray) -> boolean")]
190#[function("less_than(struct, struct) -> boolean")]
191pub fn general_lt<T1, T2, T3>(l: T1, r: T2) -> bool
192where
193    T1: Into<T3> + Debug,
194    T2: Into<T3> + Debug,
195    T3: Ord,
196{
197    l.into() < r.into()
198}
199
200#[function(
201    "is_distinct_from(boolean, boolean) -> boolean",
202    batch_fn = "boolarray_is_distinct_from"
203)]
204#[function("is_distinct_from(*int, *int) -> boolean")]
205#[function("is_distinct_from(decimal, decimal) -> boolean")]
206#[function("is_distinct_from(*float, *float) -> boolean")]
207#[function("is_distinct_from(serial, serial) -> boolean")]
208#[function("is_distinct_from(int256, int256) -> boolean")]
209#[function("is_distinct_from(date, date) -> boolean")]
210#[function("is_distinct_from(time, time) -> boolean")]
211#[function("is_distinct_from(interval, interval) -> boolean")]
212#[function("is_distinct_from(timestamp, timestamp) -> boolean")]
213#[function("is_distinct_from(timestamptz, timestamptz) -> boolean")]
214#[function("is_distinct_from(date, timestamp) -> boolean")]
215#[function("is_distinct_from(timestamp, date) -> boolean")]
216#[function("is_distinct_from(time, interval) -> boolean")]
217#[function("is_distinct_from(interval, time) -> boolean")]
218#[function("is_distinct_from(varchar, varchar) -> boolean")]
219#[function("is_distinct_from(bytea, bytea) -> boolean")]
220#[function("is_distinct_from(anyarray, anyarray) -> boolean")]
221#[function("is_distinct_from(struct, struct) -> boolean")]
222pub fn general_is_distinct_from<T1, T2, T3>(l: Option<T1>, r: Option<T2>) -> bool
223where
224    T1: Into<T3> + Debug,
225    T2: Into<T3> + Debug,
226    T3: Ord,
227{
228    l.map(Into::into) != r.map(Into::into)
229}
230
231#[function(
232    "is_not_distinct_from(boolean, boolean) -> boolean",
233    batch_fn = "boolarray_is_not_distinct_from"
234)]
235#[function("is_not_distinct_from(*int, *int) -> boolean")]
236#[function("is_not_distinct_from(decimal, decimal) -> boolean")]
237#[function("is_not_distinct_from(*float, *float) -> boolean")]
238#[function("is_not_distinct_from(serial, serial) -> boolean")]
239#[function("is_not_distinct_from(int256, int256) -> boolean")]
240#[function("is_not_distinct_from(date, date) -> boolean")]
241#[function("is_not_distinct_from(time, time) -> boolean")]
242#[function("is_not_distinct_from(interval, interval) -> boolean")]
243#[function("is_not_distinct_from(timestamp, timestamp) -> boolean")]
244#[function("is_not_distinct_from(timestamptz, timestamptz) -> boolean")]
245#[function("is_not_distinct_from(date, timestamp) -> boolean")]
246#[function("is_not_distinct_from(timestamp, date) -> boolean")]
247#[function("is_not_distinct_from(time, interval) -> boolean")]
248#[function("is_not_distinct_from(interval, time) -> boolean")]
249#[function("is_not_distinct_from(varchar, varchar) -> boolean")]
250#[function("is_not_distinct_from(bytea, bytea) -> boolean")]
251#[function("is_not_distinct_from(anyarray, anyarray) -> boolean")]
252#[function("is_not_distinct_from(struct, struct) -> boolean")]
253pub fn general_is_not_distinct_from<T1, T2, T3>(l: Option<T1>, r: Option<T2>) -> bool
254where
255    T1: Into<T3> + Debug,
256    T2: Into<T3> + Debug,
257    T3: Ord,
258{
259    l.map(Into::into) == r.map(Into::into)
260}
261
262#[function("is_true(boolean) -> boolean", batch_fn = "boolarray_is_true")]
263pub fn is_true(v: Option<bool>) -> bool {
264    v == Some(true)
265}
266
267#[function("is_not_true(boolean) -> boolean", batch_fn = "boolarray_is_not_true")]
268pub fn is_not_true(v: Option<bool>) -> bool {
269    v != Some(true)
270}
271
272#[function("is_false(boolean) -> boolean", batch_fn = "boolarray_is_false")]
273pub fn is_false(v: Option<bool>) -> bool {
274    v == Some(false)
275}
276
277#[function(
278    "is_not_false(boolean) -> boolean",
279    batch_fn = "boolarray_is_not_false"
280)]
281pub fn is_not_false(v: Option<bool>) -> bool {
282    v != Some(false)
283}
284
285#[function("is_null(*) -> boolean", batch_fn = "batch_is_null")]
286fn is_null<T>(v: Option<T>) -> bool {
287    v.is_none()
288}
289
290#[function("is_not_null(*) -> boolean", batch_fn = "batch_is_not_null")]
291fn is_not_null<T>(v: Option<T>) -> bool {
292    v.is_some()
293}
294
295#[function("greatest(...) -> boolean")]
296#[function("greatest(...) -> *int")]
297#[function("greatest(...) -> decimal")]
298#[function("greatest(...) -> *float")]
299#[function("greatest(...) -> serial")]
300#[function("greatest(...) -> int256")]
301#[function("greatest(...) -> date")]
302#[function("greatest(...) -> time")]
303#[function("greatest(...) -> interval")]
304#[function("greatest(...) -> timestamp")]
305#[function("greatest(...) -> timestamptz")]
306#[function("greatest(...) -> varchar")]
307#[function("greatest(...) -> bytea")]
308pub fn general_variadic_greatest<T>(row: impl Row) -> Option<T>
309where
310    T: Scalar,
311    for<'a> <T as Scalar>::ScalarRefType<'a>: TryFrom<ScalarRefImpl<'a>> + Ord + Debug,
312{
313    row.iter()
314        .flatten()
315        .map(
316            |scalar| match <<T as Scalar>::ScalarRefType<'_>>::try_from(scalar) {
317                Ok(v) => v,
318                Err(_) => unreachable!("all input type should have been aligned in the frontend"),
319            },
320        )
321        .max()
322        .map(|v| v.to_owned_scalar())
323}
324
325#[function("least(...) -> boolean")]
326#[function("least(...) -> *int")]
327#[function("least(...) -> decimal")]
328#[function("least(...) -> *float")]
329#[function("least(...) -> serial")]
330#[function("least(...) -> int256")]
331#[function("least(...) -> date")]
332#[function("least(...) -> time")]
333#[function("least(...) -> interval")]
334#[function("least(...) -> timestamp")]
335#[function("least(...) -> timestamptz")]
336#[function("least(...) -> varchar")]
337#[function("least(...) -> bytea")]
338pub fn general_variadic_least<T>(row: impl Row) -> Option<T>
339where
340    T: Scalar,
341    for<'a> <T as Scalar>::ScalarRefType<'a>: TryFrom<ScalarRefImpl<'a>> + Ord + Debug,
342{
343    row.iter()
344        .flatten()
345        .map(
346            |scalar| match <<T as Scalar>::ScalarRefType<'_>>::try_from(scalar) {
347                Ok(v) => v,
348                Err(_) => unreachable!("all input type should have been aligned in the frontend"),
349            },
350        )
351        .min()
352        .map(|v| v.to_owned_scalar())
353}
354
355// optimized functions for bool arrays
356
357fn boolarray_eq(l: &BoolArray, r: &BoolArray) -> BoolArray {
358    let data = !(l.data() ^ r.data());
359    let bitmap = l.null_bitmap() & r.null_bitmap();
360    BoolArray::new(data, bitmap)
361}
362
363fn boolarray_ne(l: &BoolArray, r: &BoolArray) -> BoolArray {
364    let data = l.data() ^ r.data();
365    let bitmap = l.null_bitmap() & r.null_bitmap();
366    BoolArray::new(data, bitmap)
367}
368
369fn boolarray_gt(l: &BoolArray, r: &BoolArray) -> BoolArray {
370    let data = l.data() & !r.data();
371    let bitmap = l.null_bitmap() & r.null_bitmap();
372    BoolArray::new(data, bitmap)
373}
374
375fn boolarray_lt(l: &BoolArray, r: &BoolArray) -> BoolArray {
376    let data = !l.data() & r.data();
377    let bitmap = l.null_bitmap() & r.null_bitmap();
378    BoolArray::new(data, bitmap)
379}
380
381fn boolarray_ge(l: &BoolArray, r: &BoolArray) -> BoolArray {
382    let data = l.data() | !r.data();
383    let bitmap = l.null_bitmap() & r.null_bitmap();
384    BoolArray::new(data, bitmap)
385}
386
387fn boolarray_le(l: &BoolArray, r: &BoolArray) -> BoolArray {
388    let data = !l.data() | r.data();
389    let bitmap = l.null_bitmap() & r.null_bitmap();
390    BoolArray::new(data, bitmap)
391}
392
393fn boolarray_is_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray {
394    let data = ((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap()))
395        | (l.null_bitmap() ^ r.null_bitmap());
396    BoolArray::new(data, Bitmap::ones(l.len()))
397}
398
399fn boolarray_is_not_distinct_from(l: &BoolArray, r: &BoolArray) -> BoolArray {
400    let data = !(((l.data() ^ r.data()) & (l.null_bitmap() & r.null_bitmap()))
401        | (l.null_bitmap() ^ r.null_bitmap()));
402    BoolArray::new(data, Bitmap::ones(l.len()))
403}
404
405fn boolarray_is_true(a: &BoolArray) -> BoolArray {
406    BoolArray::new(a.to_bitmap(), Bitmap::ones(a.len()))
407}
408
409fn boolarray_is_not_true(a: &BoolArray) -> BoolArray {
410    BoolArray::new(!a.to_bitmap(), Bitmap::ones(a.len()))
411}
412
413fn boolarray_is_false(a: &BoolArray) -> BoolArray {
414    BoolArray::new(!a.data() & a.null_bitmap(), Bitmap::ones(a.len()))
415}
416
417fn boolarray_is_not_false(a: &BoolArray) -> BoolArray {
418    BoolArray::new(a.data() | !a.null_bitmap(), Bitmap::ones(a.len()))
419}
420
421fn batch_is_null(a: &impl Array) -> BoolArray {
422    BoolArray::new(!a.null_bitmap(), Bitmap::ones(a.len()))
423}
424
425fn batch_is_not_null(a: &impl Array) -> BoolArray {
426    BoolArray::new(a.null_bitmap().clone(), Bitmap::ones(a.len()))
427}
428
429#[function("secure_compare(varchar, varchar) -> boolean")]
430pub fn secure_compare(left: &str, right: &str) -> bool {
431    constant_time_eq(left.as_bytes(), right.as_bytes())
432}
433
434#[function("check_not_null(any, varchar, varchar) -> any")]
435fn check_not_null<'a>(
436    v: Option<ScalarRefImpl<'a>>,
437    col_name: &str,
438    relation_name: &str,
439) -> Result<Option<ScalarRefImpl<'a>>> {
440    if v.is_none() {
441        return Err(ExprError::NotNullViolation {
442            col_name: col_name.into(),
443            table_name: relation_name.into(),
444        });
445    }
446    Ok(v)
447}
448
449#[cfg(test)]
450mod tests {
451    use std::str::FromStr;
452
453    use risingwave_common::types::{Decimal, F32, F64, Timestamp};
454    use risingwave_expr::expr::build_from_pretty;
455
456    use super::*;
457
458    #[test]
459    fn test_comparison() {
460        assert!(general_eq::<Decimal, i32, Decimal>(dec("1.0"), 1));
461        assert!(!general_ne::<Decimal, i32, Decimal>(dec("1.0"), 1));
462        assert!(!general_gt::<Decimal, i32, Decimal>(dec("1.0"), 2));
463        assert!(general_le::<Decimal, i32, Decimal>(dec("1.0"), 2));
464        assert!(!general_ge::<Decimal, i32, Decimal>(dec("1.0"), 2));
465        assert!(general_lt::<Decimal, i32, Decimal>(dec("1.0"), 2));
466        assert!(general_is_distinct_from::<Decimal, i32, Decimal>(
467            Some(dec("1.0")),
468            Some(2)
469        ));
470        assert!(general_is_distinct_from::<Decimal, i32, Decimal>(
471            None,
472            Some(1)
473        ));
474        assert!(!general_is_distinct_from::<Decimal, i32, Decimal>(
475            Some(dec("1.0")),
476            Some(1)
477        ));
478        assert!(general_eq::<F32, i32, F64>(1.0.into(), 1));
479        assert!(!general_ne::<F32, i32, F64>(1.0.into(), 1));
480        assert!(!general_lt::<F32, i32, F64>(1.0.into(), 1));
481        assert!(general_le::<F32, i32, F64>(1.0.into(), 1));
482        assert!(!general_gt::<F32, i32, F64>(1.0.into(), 1));
483        assert!(general_ge::<F32, i32, F64>(1.0.into(), 1));
484        assert!(!general_is_distinct_from::<F32, i32, F64>(
485            Some(1.0.into()),
486            Some(1)
487        ));
488        assert!(general_eq::<i64, i32, i64>(1i64, 1));
489        assert!(!general_ne::<i64, i32, i64>(1i64, 1));
490        assert!(!general_lt::<i64, i32, i64>(1i64, 1));
491        assert!(general_le::<i64, i32, i64>(1i64, 1));
492        assert!(!general_gt::<i64, i32, i64>(1i64, 1));
493        assert!(general_ge::<i64, i32, i64>(1i64, 1));
494        assert!(!general_is_distinct_from::<i64, i32, i64>(
495            Some(1i64),
496            Some(1)
497        ));
498    }
499
500    fn dec(s: &str) -> Decimal {
501        Decimal::from_str(s).unwrap()
502    }
503
504    #[tokio::test]
505    async fn test_is_distinct_from() {
506        let (input, target) = DataChunk::from_pretty(
507            "
508            i i B
509            . . f
510            . 1 t
511            1 . t
512            2 2 f
513            3 4 t
514        ",
515        )
516        .split_column_at(2);
517        let expr = build_from_pretty("(is_distinct_from:boolean $0:int4 $1:int4)");
518        let result = expr.eval(&input).await.unwrap();
519        assert_eq!(&result, target.column_at(0));
520    }
521
522    #[tokio::test]
523    async fn test_is_not_distinct_from() {
524        let (input, target) = DataChunk::from_pretty(
525            "
526            i i B
527            . . t
528            . 1 f
529            1 . f
530            2 2 t
531            3 4 f
532            ",
533        )
534        .split_column_at(2);
535        let expr = build_from_pretty("(is_not_distinct_from:boolean $0:int4 $1:int4)");
536        let result = expr.eval(&input).await.unwrap();
537        assert_eq!(&result, target.column_at(0));
538    }
539
540    use risingwave_common::array::*;
541    use risingwave_common::row::OwnedRow;
542    use risingwave_common::types::test_utils::IntervalTestExt;
543    use risingwave_common::types::{Date, Interval};
544    use risingwave_pb::expr::expr_node::Type;
545
546    use crate::scalar::arithmetic_op::{date_interval_add, date_interval_sub};
547
548    #[tokio::test]
549    async fn test_binary() {
550        test_binary_i32::<I32Array, _>(|x, y| x + y, Type::Add).await;
551        test_binary_i32::<I32Array, _>(|x, y| x - y, Type::Subtract).await;
552        test_binary_i32::<I32Array, _>(|x, y| x * y, Type::Multiply).await;
553        test_binary_i32::<I32Array, _>(|x, y| x / y, Type::Divide).await;
554        test_binary_i32::<BoolArray, _>(|x, y| x == y, Type::Equal).await;
555        test_binary_i32::<BoolArray, _>(|x, y| x != y, Type::NotEqual).await;
556        test_binary_i32::<BoolArray, _>(|x, y| x > y, Type::GreaterThan).await;
557        test_binary_i32::<BoolArray, _>(|x, y| x >= y, Type::GreaterThanOrEqual).await;
558        test_binary_i32::<BoolArray, _>(|x, y| x < y, Type::LessThan).await;
559        test_binary_i32::<BoolArray, _>(|x, y| x <= y, Type::LessThanOrEqual).await;
560        test_binary_inner::<I32Array, I32Array, I32Array, _>(
561            reduce(std::cmp::max::<i32>),
562            Type::Greatest,
563        )
564        .await;
565        test_binary_inner::<I32Array, I32Array, I32Array, _>(
566            reduce(std::cmp::min::<i32>),
567            Type::Least,
568        )
569        .await;
570        test_binary_inner::<BoolArray, BoolArray, BoolArray, _>(
571            reduce(std::cmp::max::<bool>),
572            Type::Greatest,
573        )
574        .await;
575        test_binary_inner::<BoolArray, BoolArray, BoolArray, _>(
576            reduce(std::cmp::min::<bool>),
577            Type::Least,
578        )
579        .await;
580        test_binary_decimal::<DecimalArray, _>(|x, y| x + y, Type::Add).await;
581        test_binary_decimal::<DecimalArray, _>(|x, y| x - y, Type::Subtract).await;
582        test_binary_decimal::<DecimalArray, _>(|x, y| x * y, Type::Multiply).await;
583        test_binary_decimal::<DecimalArray, _>(|x, y| x / y, Type::Divide).await;
584        test_binary_decimal::<BoolArray, _>(|x, y| x == y, Type::Equal).await;
585        test_binary_decimal::<BoolArray, _>(|x, y| x != y, Type::NotEqual).await;
586        test_binary_decimal::<BoolArray, _>(|x, y| x > y, Type::GreaterThan).await;
587        test_binary_decimal::<BoolArray, _>(|x, y| x >= y, Type::GreaterThanOrEqual).await;
588        test_binary_decimal::<BoolArray, _>(|x, y| x < y, Type::LessThan).await;
589        test_binary_decimal::<BoolArray, _>(|x, y| x <= y, Type::LessThanOrEqual).await;
590        test_binary_interval::<TimestampArray, _>(
591            |x, y| date_interval_add(x, y).unwrap(),
592            Type::Add,
593        )
594        .await;
595        test_binary_interval::<TimestampArray, _>(
596            |x, y| date_interval_sub(x, y).unwrap(),
597            Type::Subtract,
598        )
599        .await;
600    }
601
602    trait TestFrom: Copy {
603        const NAME: &'static str;
604        fn test_from(i: usize) -> Self;
605    }
606
607    impl TestFrom for i32 {
608        const NAME: &'static str = "int4";
609
610        fn test_from(i: usize) -> Self {
611            i as i32
612        }
613    }
614
615    impl TestFrom for Decimal {
616        const NAME: &'static str = "decimal";
617
618        fn test_from(i: usize) -> Self {
619            i.into()
620        }
621    }
622
623    impl TestFrom for bool {
624        const NAME: &'static str = "boolean";
625
626        fn test_from(i: usize) -> Self {
627            i % 2 == 0
628        }
629    }
630
631    impl TestFrom for Timestamp {
632        const NAME: &'static str = "timestamp";
633
634        fn test_from(_: usize) -> Self {
635            unimplemented!("not implemented as input yet")
636        }
637    }
638
639    impl TestFrom for Interval {
640        const NAME: &'static str = "interval";
641
642        fn test_from(i: usize) -> Self {
643            Interval::from_ymd(0, i as _, i as _)
644        }
645    }
646
647    impl TestFrom for Date {
648        const NAME: &'static str = "date";
649
650        fn test_from(i: usize) -> Self {
651            Date::from_num_days_from_ce_uncheck(i as i32)
652        }
653    }
654
655    #[expect(clippy::type_complexity)]
656    fn gen_test_data<L: TestFrom, R: TestFrom, O>(
657        count: usize,
658        f: impl Fn(Option<L>, Option<R>) -> Option<O>,
659    ) -> (Vec<Option<L>>, Vec<Option<R>>, Vec<Option<O>>) {
660        let mut lhs = Vec::<Option<L>>::new();
661        let mut rhs = Vec::<Option<R>>::new();
662        let mut target = Vec::<Option<O>>::new();
663        for i in 0..count {
664            let (l, r) = if i % 2 == 0 {
665                (Some(i), None)
666            } else if i % 3 == 0 {
667                (Some(i), Some(i + 1))
668            } else if i % 5 == 0 {
669                (Some(i + 1), Some(i))
670            } else if i % 7 == 0 {
671                (None, Some(i))
672            } else {
673                (Some(i), Some(i))
674            };
675            let l = l.map(TestFrom::test_from);
676            let r = r.map(TestFrom::test_from);
677            lhs.push(l);
678            rhs.push(r);
679            target.push(f(l, r));
680        }
681        (lhs, rhs, target)
682    }
683
684    fn arithmetic<L, R, O>(f: impl Fn(L, R) -> O) -> impl Fn(Option<L>, Option<R>) -> Option<O> {
685        move |l, r| match (l, r) {
686            (Some(l), Some(r)) => Some(f(l, r)),
687            _ => None,
688        }
689    }
690
691    fn reduce<I>(f: impl Fn(I, I) -> I) -> impl Fn(Option<I>, Option<I>) -> Option<I> {
692        move |l, r| match (l, r) {
693            (Some(l), Some(r)) => Some(f(l, r)),
694            (Some(l), None) => Some(l),
695            (None, Some(r)) => Some(r),
696            (None, None) => None,
697        }
698    }
699
700    async fn test_binary_inner<L, R, A, F>(f: F, kind: Type)
701    where
702        L: Array,
703        L: for<'a> FromIterator<&'a Option<<L as Array>::OwnedItem>>,
704        <L as Array>::OwnedItem: TestFrom,
705        R: Array,
706        R: for<'a> FromIterator<&'a Option<<R as Array>::OwnedItem>>,
707        <R as Array>::OwnedItem: TestFrom,
708        A: Array,
709        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
710        for<'a> <A as Array>::RefItem<'a>: PartialEq,
711        <A as Array>::OwnedItem: TestFrom,
712        F: Fn(
713            Option<<L as Array>::OwnedItem>,
714            Option<<R as Array>::OwnedItem>,
715        ) -> Option<<A as Array>::OwnedItem>,
716    {
717        let (lhs, rhs, target) = gen_test_data(100, f);
718
719        let col1 = L::from_iter(&lhs).into_ref();
720        let col2 = R::from_iter(&rhs).into_ref();
721        let data_chunk = DataChunk::new(vec![col1, col2], 100);
722        let l_name = <<L as Array>::OwnedItem as TestFrom>::NAME;
723        let r_name = <<R as Array>::OwnedItem as TestFrom>::NAME;
724        let output_name = <<A as Array>::OwnedItem as TestFrom>::NAME;
725        let expr = build_from_pretty(format!(
726            "({name}:{output_name} $0:{l_name} $1:{r_name})",
727            name = kind.as_str_name(),
728        ));
729        let res = expr.eval(&data_chunk).await.unwrap();
730        let arr: &A = res.as_ref().into();
731        for (idx, item) in arr.iter().enumerate() {
732            let x = target[idx].as_ref().map(|x| x.as_scalar_ref());
733            assert_eq!(x, item);
734        }
735
736        for i in 0..lhs.len() {
737            let row = OwnedRow::new(vec![
738                lhs[i].map(|int| int.to_scalar_value()),
739                rhs[i].map(|int| int.to_scalar_value()),
740            ]);
741            let result = expr.eval_row(&row).await.unwrap();
742            let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value());
743            assert_eq!(result, expected);
744        }
745    }
746
747    async fn test_binary_i32<A, F>(f: F, kind: Type)
748    where
749        A: Array,
750        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
751        for<'a> <A as Array>::RefItem<'a>: PartialEq,
752        <A as Array>::OwnedItem: TestFrom,
753        F: Fn(i32, i32) -> <A as Array>::OwnedItem,
754    {
755        test_binary_inner::<I32Array, I32Array, _, _>(arithmetic(f), kind).await
756    }
757
758    async fn test_binary_interval<A, F>(f: F, kind: Type)
759    where
760        A: Array,
761        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
762        for<'a> <A as Array>::RefItem<'a>: PartialEq,
763        <A as Array>::OwnedItem: TestFrom,
764        F: Fn(Date, Interval) -> <A as Array>::OwnedItem,
765    {
766        test_binary_inner::<DateArray, IntervalArray, _, _>(arithmetic(f), kind).await
767    }
768
769    async fn test_binary_decimal<A, F>(f: F, kind: Type)
770    where
771        A: Array,
772        for<'a> &'a A: std::convert::From<&'a ArrayImpl>,
773        for<'a> <A as Array>::RefItem<'a>: PartialEq,
774        <A as Array>::OwnedItem: TestFrom,
775        F: Fn(Decimal, Decimal) -> <A as Array>::OwnedItem,
776    {
777        test_binary_inner::<DecimalArray, DecimalArray, _, _>(arithmetic(f), kind).await
778    }
779}