risingwave_expr_impl/scalar/
array_contain.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Array containment and overlap expression functions.
16
17use std::collections::HashSet;
18
19use risingwave_common::types::ListRef;
20use risingwave_expr::function;
21
22/// Returns whether left range contains right range.
23///
24/// Examples:
25///
26/// ```slt
27/// query I
28/// select array[1,2,3] @> array[2,3];
29/// ----
30/// t
31///
32/// query I
33/// select array[1,2,3] @> array[3,4];
34/// ----
35/// f
36///
37/// query I
38/// SELECT array[1,2,3] @> array[3,1];
39/// ----
40/// t
41///
42/// query I
43/// SELECT array[1,2] @> array[1,1];
44/// ----
45/// t
46///
47/// query I
48/// SELECT array[1,2,3] @> array[]::int[];
49/// ----
50/// t
51///
52/// query I
53/// SELECT ARRAY[]::int[] @> ARRAY[]::int[];
54/// ----
55/// t
56///
57/// query I
58/// select array[[[1,2],[3,4]],[[5,6],[7,8]]] @> array[2,3];
59/// ----
60/// t
61///
62/// query I
63/// select array[1,2,3] @> null;
64/// ----
65/// NULL
66///
67/// query I
68/// select null @> array[3,4];
69/// ----
70/// NULL
71///
72/// query I
73/// select array[1,null,2] @> array[1,null,2];
74/// ----
75/// f
76///
77/// query I
78/// select array[1,null,2] @> array[1,2];
79/// ----
80/// t
81///
82/// query I
83/// SELECT array[1,NULL,2] @> array[NULL]::int[];
84/// ----
85/// f
86///
87/// query I
88/// SELECT NULL::int[] @> ARRAY[1];
89/// ----
90/// NULL
91/// ```
92fn array_contains_impl(left: ListRef<'_>, right: ListRef<'_>) -> bool {
93    let flatten = left.flatten();
94    let set: HashSet<_> = flatten.iter().collect();
95    right
96        .flatten()
97        .iter()
98        .all(|item| item.is_some_and(|v| set.contains(&Some(v))))
99}
100
101#[function("array_contains(anyarray, anyarray) -> boolean")]
102fn array_contains(left: ListRef<'_>, right: ListRef<'_>) -> bool {
103    array_contains_impl(left, right)
104}
105
106#[function("array_contained(anyarray, anyarray) -> boolean")]
107fn array_contained(left: ListRef<'_>, right: ListRef<'_>) -> bool {
108    array_contains_impl(right, left)
109}
110
111fn array_overlaps_impl(left: ListRef<'_>, right: ListRef<'_>) -> bool {
112    let flatten = left.flatten();
113    let set: HashSet<_> = flatten.iter().flatten().collect();
114    right
115        .flatten()
116        .iter()
117        .any(|item| item.is_some_and(|v| set.contains(&v)))
118}
119
120#[function("array_overlaps(anyarray, anyarray) -> boolean")]
121fn array_overlaps(left: ListRef<'_>, right: ListRef<'_>) -> bool {
122    array_overlaps_impl(left, right)
123}
124
#[cfg(test)]
mod tests {
    use risingwave_common::types::{DataType, ListValue, Scalar, ScalarImpl};

    use super::*;

    /// A present element is contained; an absent one is not.
    #[test]
    fn test_contains() {
        let haystack = ListValue::from_iter([2, 3]);

        let present = ListValue::from_iter([2]);
        assert!(array_contains_impl(
            haystack.as_scalar_ref(),
            present.as_scalar_ref(),
        ));

        let absent = ListValue::from_iter([5]);
        assert!(!array_contains_impl(
            haystack.as_scalar_ref(),
            absent.as_scalar_ref(),
        ));
    }

    /// Overlap requires a shared non-null element; NULLs never match.
    #[test]
    fn test_overlaps() {
        let base = ListValue::from_iter([2, 3]);

        // One shared element (3) is enough.
        let sharing = ListValue::from_iter([3, 5]);
        assert!(array_overlaps_impl(
            base.as_scalar_ref(),
            sharing.as_scalar_ref(),
        ));

        // No shared element at all.
        let disjoint = ListValue::from_iter([4, 5]);
        assert!(!array_overlaps_impl(
            base.as_scalar_ref(),
            disjoint.as_scalar_ref(),
        ));

        // Both sides hold a NULL, but their non-null values differ.
        let one_and_null = ListValue::from_datum_iter(
            &DataType::Int32,
            [Some(ScalarImpl::Int32(1)), None::<ScalarImpl>],
        );
        let null_and_two = ListValue::from_datum_iter(
            &DataType::Int32,
            [None::<ScalarImpl>, Some(ScalarImpl::Int32(2))],
        );
        assert!(!array_overlaps_impl(
            one_and_null.as_scalar_ref(),
            null_and_two.as_scalar_ref(),
        ));

        // Arrays consisting solely of NULL do not overlap each other.
        let null_left = ListValue::from_datum_iter(&DataType::Int32, [None::<ScalarImpl>]);
        let null_right = ListValue::from_datum_iter(&DataType::Int32, [None::<ScalarImpl>]);
        assert!(!array_overlaps_impl(
            null_left.as_scalar_ref(),
            null_right.as_scalar_ref(),
        ));
    }
}