risingwave_expr_impl/scalar/array_positions.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::{I32Array, ListRef, ListValue};
16use risingwave_common::types::ScalarRefImpl;
17use risingwave_expr::{ExprError, Result, function};
18
19/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if
20/// it's not present.
21///
22/// Examples:
23///
24/// ```slt
25/// query I
26/// select array_position(array[1, null, 2, null], null);
27/// ----
28/// 2
29///
30/// query I
31/// select array_position(array[3, 4, 5], 2);
32/// ----
33/// NULL
34///
35/// query I
36/// select array_position(null, 4);
37/// ----
38/// NULL
39///
40/// query I
41/// select array_position(null, null);
42/// ----
43/// NULL
44///
45/// query I
46/// select array_position('{yes}', true);
47/// ----
48/// 1
49///
50/// # Like in PostgreSQL, searching `int` in multidimensional array is disallowed.
51/// statement error
52/// select array_position(array[array[1, 2], array[3, 4]], 1);
53///
54/// # Unlike in PostgreSQL, it is okay to search `int[]` inside `int[][]`.
55/// query I
56/// select array_position(array[array[1, 2], array[3, 4]], array[3, 4]);
57/// ----
58/// 2
59///
60/// statement error
61/// select array_position(array[3, 4], true);
62///
63/// query I
64/// select array_position(array[3, 4], 4.0);
65/// ----
66/// 2
67/// ```
68#[function("array_position(anyarray, any) -> int4")]
69pub(super) fn array_position(
70 array: ListRef<'_>,
71 element: Option<ScalarRefImpl<'_>>,
72) -> Result<Option<i32>> {
73 array_position_common(array, element, 0)
74}
75
76/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if
77/// it's not present. The search begins at the third argument.
78///
79/// Examples:
80///
81/// ```slt
82/// statement error
83/// select array_position(array[1, null, 2, null], null, false);
84///
85/// statement error
86/// select array_position(array[1, null, 2, null], null, null::int);
87///
88/// query II
89/// select v, array_position(array[1, null, 2, null], null, v) from generate_series(-1, 5) as t(v);
90/// ----
91/// -1 2
92/// 0 2
93/// 1 2
94/// 2 2
95/// 3 4
96/// 4 4
97/// 5 NULL
98/// ```
99#[function("array_position(anyarray, any, int4) -> int4")]
100fn array_position_start(
101 array: ListRef<'_>,
102 element: Option<ScalarRefImpl<'_>>,
103 start: Option<i32>,
104) -> Result<Option<i32>> {
105 let start = match start {
106 None => {
107 return Err(ExprError::InvalidParam {
108 name: "start",
109 reason: "initial position must not be null".into(),
110 });
111 }
112 Some(start) => (start.max(1) - 1) as usize,
113 };
114 array_position_common(array, element, start)
115}
116
117fn array_position_common(
118 array: ListRef<'_>,
119 element: Option<ScalarRefImpl<'_>>,
120 skip: usize,
121) -> Result<Option<i32>> {
122 if i32::try_from(array.len()).is_err() {
123 return Err(ExprError::CastOutOfRange("invalid array length"));
124 }
125
126 Ok(array
127 .iter()
128 .skip(skip)
129 .position(|item| item == element)
130 .map(|idx| (idx + 1 + skip) as _))
131}
132
133/// Returns an array of the subscripts of all occurrences of the second argument in the array
134/// given as first argument. Note the behavior is slightly different from PG.
135///
136/// Examples:
137///
138/// ```slt
139/// query T
140/// select array_positions(array[array[1],array[2],array[3],array[2],null], array[1]);
141/// ----
142/// {1}
143///
144/// query T
145/// select array_positions(array[array[1],array[2],array[3],array[2],null], array[2]);
146/// ----
147/// {2,4}
148///
149/// query T
150/// select array_positions(array[array[1],array[2],array[3],array[2],null], null);
151/// ----
152/// {5}
153///
154/// query T
155/// select array_positions(array[array[1],array[2],array[3],array[2],null], array[4]);
156/// ----
157/// {}
158///
159/// query T
160/// select array_positions(null, 1);
161/// ----
162/// NULL
163///
164/// query T
165/// select array_positions(ARRAY[array[1],array[2],array[3],array[2],null], array[3.14]);
166/// ----
167/// {}
168///
169/// query T
170/// select array_positions(array[1,NULL,NULL,3], NULL);
171/// ----
172/// {2,3}
173///
174/// statement error
175/// select array_positions(array[array[1],array[2],array[3],array[2],null], 1);
176///
177/// statement error
178/// select array_positions(array[array[1],array[2],array[3],array[2],null], array[array[3]]);
179///
180/// statement error
181/// select array_positions(ARRAY[array[1],array[2],array[3],array[2],null], array[true]);
182/// ```
183#[function("array_positions(anyarray, any) -> int4[]")]
184fn array_positions(
185 array: Option<ListRef<'_>>,
186 element: Option<ScalarRefImpl<'_>>,
187) -> Result<Option<ListValue>> {
188 let Some(array) = array else {
189 return Ok(None);
190 };
191 let values = array.iter();
192 if values.len() - 1 > i32::MAX as usize {
193 return Err(ExprError::CastOutOfRange("invalid array length"));
194 }
195 Ok(Some(ListValue::new(
196 values
197 .enumerate()
198 .filter(|(_, item)| item == &element)
199 .map(|(idx, _)| idx as i32 + 1)
200 .collect::<I32Array>()
201 .into(),
202 )))
203}