risingwave_expr_impl/scalar/
array_contain.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Array containment functions (`@>` and `<@`).
16
17use std::collections::HashSet;
18
19use risingwave_common::types::ListRef;
20use risingwave_expr::function;
21
22/// Returns whether left range contains right range.
23///
24/// Examples:
25///
26/// ```slt
27/// query I
28/// select array[1,2,3] @> array[2,3];
29/// ----
30/// t
31///
32/// query I
33/// select array[1,2,3] @> array[3,4];
34/// ----
35/// f
36///
37/// query I
38/// select array[[[1,2],[3,4]],[[5,6],[7,8]]] @> array[2,3];
39/// ----
40/// t
41///
42/// query I
43/// select array[1,2,3] @> null;
44/// ----
45/// NULL
46///
47/// query I
48/// select null @> array[3,4];
49/// ----
50/// NULL
51/// ```
52#[function("array_contains(anyarray, anyarray) -> boolean")]
53fn array_contains(left: ListRef<'_>, right: ListRef<'_>) -> bool {
54    let flatten = left.flatten();
55    let set: HashSet<_> = flatten.iter().collect();
56    right.flatten().iter().all(|item| set.contains(&item))
57}
58
59#[function("array_contained(anyarray, anyarray) -> boolean")]
60fn array_contained(left: ListRef<'_>, right: ListRef<'_>) -> bool {
61    array_contains(right, left)
62}
63
#[cfg(test)]
mod tests {
    use risingwave_common::types::{ListValue, Scalar};

    use super::*;

    #[test]
    fn test_contains() {
        assert!(array_contains(
            ListValue::from_iter([2, 3]).as_scalar_ref(),
            ListValue::from_iter([2]).as_scalar_ref(),
        ));
        assert!(!array_contains(
            ListValue::from_iter([2, 3]).as_scalar_ref(),
            ListValue::from_iter([5]).as_scalar_ref(),
        ));
        // Duplicates are not treated specially: `[1]` and `[1, 1]` contain each other.
        assert!(array_contains(
            ListValue::from_iter([1]).as_scalar_ref(),
            ListValue::from_iter([1, 1]).as_scalar_ref(),
        ));
        assert!(array_contains(
            ListValue::from_iter([1, 1]).as_scalar_ref(),
            ListValue::from_iter([1]).as_scalar_ref(),
        ));
    }

    #[test]
    fn test_contained() {
        // `array_contained(a, b)` is `array_contains(b, a)`.
        assert!(array_contained(
            ListValue::from_iter([2]).as_scalar_ref(),
            ListValue::from_iter([2, 3]).as_scalar_ref(),
        ));
        assert!(!array_contained(
            ListValue::from_iter([2, 3]).as_scalar_ref(),
            ListValue::from_iter([2]).as_scalar_ref(),
        ));
    }
}