risingwave_expr_impl/scalar/array_contain.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Array containment functions (`array_contains` / `array_contained`).
16
17use std::collections::HashSet;
18
19use risingwave_common::types::ListRef;
20use risingwave_expr::function;
21
22/// Returns whether left range contains right range.
23///
24/// Examples:
25///
26/// ```slt
27/// query I
28/// select array[1,2,3] @> array[2,3];
29/// ----
30/// t
31///
32/// query I
33/// select array[1,2,3] @> array[3,4];
34/// ----
35/// f
36///
37/// query I
38/// select array[[[1,2],[3,4]],[[5,6],[7,8]]] @> array[2,3];
39/// ----
40/// t
41///
42/// query I
43/// select array[1,2,3] @> null;
44/// ----
45/// NULL
46///
47/// query I
48/// select null @> array[3,4];
49/// ----
50/// NULL
51/// ```
52#[function("array_contains(anyarray, anyarray) -> boolean")]
53fn array_contains(left: ListRef<'_>, right: ListRef<'_>) -> bool {
54 let flatten = left.flatten();
55 let set: HashSet<_> = flatten.iter().collect();
56 right.flatten().iter().all(|item| set.contains(&item))
57}
58
59#[function("array_contained(anyarray, anyarray) -> boolean")]
60fn array_contained(left: ListRef<'_>, right: ListRef<'_>) -> bool {
61 array_contains(right, left)
62}
63
#[cfg(test)]
mod tests {
    use risingwave_common::types::{ListValue, Scalar};

    use super::*;

    #[test]
    fn test_contains() {
        assert!(array_contains(
            ListValue::from_iter([2, 3]).as_scalar_ref(),
            ListValue::from_iter([2]).as_scalar_ref(),
        ));
        assert!(!array_contains(
            ListValue::from_iter([2, 3]).as_scalar_ref(),
            ListValue::from_iter([5]).as_scalar_ref(),
        ));
        // Duplicates are not treated specially: each side contains the other.
        assert!(array_contains(
            ListValue::from_iter([2]).as_scalar_ref(),
            ListValue::from_iter([2, 2]).as_scalar_ref(),
        ));
        assert!(array_contains(
            ListValue::from_iter([2, 2]).as_scalar_ref(),
            ListValue::from_iter([2]).as_scalar_ref(),
        ));
    }

    #[test]
    fn test_contained() {
        // `array_contained(l, r)` is `array_contains(r, l)`.
        assert!(array_contained(
            ListValue::from_iter([2]).as_scalar_ref(),
            ListValue::from_iter([2, 3]).as_scalar_ref(),
        ));
        assert!(!array_contained(
            ListValue::from_iter([5]).as_scalar_ref(),
            ListValue::from_iter([2, 3]).as_scalar_ref(),
        ));
    }
}