risingwave_expr_impl/scalar/array.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::{ListValue, StructValue};
16use risingwave_common::row::Row;
17use risingwave_common::types::{
18 DataType, ListRef, MapRef, MapType, MapValue, ScalarRef, ScalarRefImpl, ToOwnedDatum,
19};
20use risingwave_expr::expr::Context;
21use risingwave_expr::{ExprError, function};
22
23use super::array_positions::array_position;
24
25#[function("array(...) -> anyarray", type_infer = "unreachable")]
26fn array(row: impl Row, ctx: &Context) -> ListValue {
27 ListValue::from_datum_iter(ctx.return_type.as_list_elem(), row.iter())
28}
29
30#[function("row(...) -> struct", type_infer = "unreachable")]
31fn row_(row: impl Row) -> StructValue {
32 StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
33}
34
35fn map_from_key_values_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
36 let map = MapType::try_from_kv(
37 args[0].as_list_elem().clone(),
38 args[1].as_list_elem().clone(),
39 )
40 .map_err(ExprError::Custom)?;
41 Ok(map.into())
42}
43
44fn map_from_entries_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
45 let map =
46 MapType::try_from_entries(args[0].as_list_elem().clone()).map_err(ExprError::Custom)?;
47 Ok(map.into())
48}
49
50/// # Example
51///
52/// ```slt
53/// query T
54/// select map_from_key_values(null::int[], array[1,2,3]);
55/// ----
56/// NULL
57///
58/// query T
59/// select map_from_key_values(array['a','b','c'], array[1,2,3]);
60/// ----
61/// {a:1,b:2,c:3}
62/// ```
63#[function(
64 "map_from_key_values(anyarray, anyarray) -> anymap",
65 type_infer = "map_from_key_values_type_infer"
66)]
67fn map_from_key_values(keys: ListRef<'_>, values: ListRef<'_>) -> Result<MapValue, ExprError> {
68 MapValue::try_from_kv(keys.to_owned_scalar(), values.to_owned_scalar())
69 .map_err(ExprError::Custom)
70}
71
72#[function(
73 "map_from_entries(anyarray) -> anymap",
74 type_infer = "map_from_entries_type_infer"
75)]
76fn map_from_entries(entries: ListRef<'_>) -> Result<MapValue, ExprError> {
77 MapValue::try_from_entries(entries.to_owned_scalar()).map_err(ExprError::Custom)
78}
79
80/// # Example
81///
82/// ```slt
83/// query T
84/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 3);
85/// ----
86/// 300
87///
88/// query T
89/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), '3');
90/// ----
91/// 300
92///
93/// query error
94/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 1.0);
95/// ----
96/// db error: ERROR: Failed to run the query
97///
98/// Caused by these errors (recent errors listed first):
99/// 1: Failed to bind expression: map_access(map_from_key_values(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
100/// 2: Bind error: Cannot access numeric in map(integer,integer)
101///
102///
103/// query T
104/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'a');
105/// ----
106/// 1
107///
108/// query T
109/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'd');
110/// ----
111/// NULL
112///
113/// query T
114/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), null);
115/// ----
116/// NULL
117/// ```
118#[function("map_access(anymap, any) -> any", type_infer = "unreachable")]
119fn map_access<'a>(
120 map: MapRef<'a>,
121 key: ScalarRefImpl<'_>,
122) -> Result<Option<ScalarRefImpl<'a>>, ExprError> {
123 // FIXME: DatumRef in return value is not support by the macro yet.
124
125 let (keys, values) = map.into_kv();
126 let idx = array_position(keys, Some(key))?;
127 match idx {
128 Some(idx) => Ok(values.get((idx - 1) as usize).unwrap()),
129 None => Ok(None),
130 }
131}
132
133/// ```slt
134/// query T
135/// select
136/// map_contains(MAP{1:1}, 1),
137/// map_contains(MAP{1:1}, 2),
138/// map_contains(MAP{1:1}, NULL),
139/// map_contains(MAP{1:1}, '1'),
140/// map_contains(MAP{'a':'1','b':'2'}, 'ab');
141/// ----
142/// t f NULL t f
143///
144///
145/// query error
146/// select map_contains(MAP{1:1}, 1.0);
147/// ----
148/// db error: ERROR: Failed to run the query
149///
150/// Caused by these errors (recent errors listed first):
151/// 1: Failed to bind expression: map_contains(MAP {1: 1}, 1.0)
152/// 2: Bind error: Cannot check if numeric exists in map(integer,integer)
153///
154///
155/// query error
156/// select map_contains(MAP{1:1}, NULL::varchar);
157/// ----
158/// db error: ERROR: Failed to run the query
159///
160/// Caused by these errors (recent errors listed first):
161/// 1: Failed to bind expression: map_contains(MAP {1: 1}, CAST(NULL AS CHARACTER VARYING))
162/// 2: Bind error: Cannot check if character varying exists in map(integer,integer)
163/// ```
164#[function("map_contains(anymap, any) -> boolean")]
165fn map_contains(map: MapRef<'_>, key: ScalarRefImpl<'_>) -> Result<bool, ExprError> {
166 let (keys, _values) = map.into_kv();
167 let idx = array_position(keys, Some(key))?;
168 Ok(idx.is_some())
169}
170
171/// ```slt
172/// query I
173/// select
174/// map_length(NULL::map(int,int)),
175/// map_length(MAP {}::map(int,int)),
176/// map_length(MAP {1:1,2:2}::map(int,int))
177/// ----
178/// NULL 0 2
179/// ```
180#[function("map_length(anymap) -> int4")]
181fn map_length<T: TryFrom<usize>>(map: MapRef<'_>) -> Result<T, ExprError> {
182 map.len().try_into().map_err(|_| ExprError::NumericOverflow)
183}
184
185/// If both `m1` and `m2` have a value with the same key, then the output map contains the value from `m2`.
186///
187/// ```slt
188/// query T
189/// select map_cat(MAP{'a':1,'b':2},null::map(varchar,int));
190/// ----
191/// {a:1,b:2}
192///
193/// query T
194/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3,'c':4});
195/// ----
196/// {a:1,b:3,c:4}
197///
198/// # implicit type cast
199/// query T
200/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3.0,'c':4.0});
201/// ----
202/// {a:1,b:3.0,c:4.0}
203/// ```
204#[function("map_cat(anymap, anymap) -> anymap")]
205fn map_cat(m1: Option<MapRef<'_>>, m2: Option<MapRef<'_>>) -> Result<Option<MapValue>, ExprError> {
206 match (m1, m2) {
207 (None, None) => Ok(None),
208 (Some(m), None) | (None, Some(m)) => Ok(Some(m.to_owned_scalar())),
209 (Some(m1), Some(m2)) => Ok(Some(MapValue::concat(m1, m2))),
210 }
211}
212
213/// Inserts a key-value pair into the map. If the key already exists, the value is updated.
214///
215/// # Example
216///
217/// ```slt
218/// query T
219/// select map_insert(map{'a':1, 'b':2}, 'c', 3);
220/// ----
221/// {a:1,b:2,c:3}
222///
223/// query T
224/// select map_insert(map{'a':1, 'b':2}, 'b', 4);
225/// ----
226/// {a:1,b:4}
227/// ```
228///
229/// TODO: support variadic arguments
230#[function("map_insert(anymap, any, any) -> anymap")]
231fn map_insert(
232 map: MapRef<'_>,
233 key: Option<ScalarRefImpl<'_>>,
234 value: Option<ScalarRefImpl<'_>>,
235) -> MapValue {
236 let Some(key) = key else {
237 return map.to_owned_scalar();
238 };
239 MapValue::insert(map, key.into_scalar_impl(), value.to_owned_datum())
240}
241
242/// Deletes a key-value pair from the map.
243///
244/// # Example
245///
246/// ```slt
247/// query T
248/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'b');
249/// ----
250/// {a:1,c:3}
251///
252/// query T
253/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'd');
254/// ----
255/// {a:1,b:2,c:3}
256/// ```
257///
258/// TODO: support variadic arguments
259#[function("map_delete(anymap, any) -> anymap")]
260fn map_delete(map: MapRef<'_>, key: Option<ScalarRefImpl<'_>>) -> MapValue {
261 let Some(key) = key else {
262 return map.to_owned_scalar();
263 };
264 MapValue::delete(map, key)
265}
266
267/// # Example
268///
269/// ```slt
270/// query T
271/// select map_keys(map{'a':1, 'b':2, 'c':3});
272/// ----
273/// {a,b,c}
274/// ```
275#[function(
276 "map_keys(anymap) -> anyarray",
277 type_infer = "|args|{
278 Ok(DataType::list(args[0].as_map().key().clone()))
279 }"
280)]
281fn map_keys(map: MapRef<'_>) -> ListValue {
282 map.into_kv().0.to_owned_scalar()
283}
284
285/// # Example
286///
287/// ```slt
288/// query T
289/// select map_values(map{'a':1, 'b':2, 'c':3});
290/// ----
291/// {1,2,3}
292/// ```
293#[function(
294 "map_values(anymap) -> anyarray",
295 type_infer = "|args|{
296 Ok(DataType::list(args[0].as_map().value().clone()))
297 }"
298)]
299fn map_values(map: MapRef<'_>) -> ListValue {
300 map.into_kv().1.to_owned_scalar()
301}
302
303/// # Example
304///
305/// ```slt
306/// query T
307/// select map_entries(map{'a':1, 'b':2, 'c':3});
308/// ----
309/// {"(a,1)","(b,2)","(c,3)"}
310/// ```
311#[function(
312 "map_entries(anymap) -> anyarray",
313 type_infer = "|args|{
314 Ok(args[0].as_map().clone().into_list())
315 }"
316)]
317fn map_entries(map: MapRef<'_>) -> ListValue {
318 map.into_inner().to_owned_scalar()
319}
320
321#[cfg(test)]
322mod tests {
323 use risingwave_common::array::DataChunk;
324 use risingwave_common::row::Row;
325 use risingwave_common::test_prelude::DataChunkTestExt;
326 use risingwave_common::types::ToOwnedDatum;
327 use risingwave_common::util::iter_util::ZipEqDebug;
328 use risingwave_expr::expr::build_from_pretty;
329
330 #[tokio::test]
331 async fn test_row_expr() {
332 let expr = build_from_pretty("(row:struct<a_int4,b_int4,c_int4> $0:int4 $1:int4 $2:int4)");
333 let (input, expected) = DataChunk::from_pretty(
334 "i i i <i,i,i>
335 1 2 3 (1,2,3)
336 4 2 1 (4,2,1)
337 9 1 3 (9,1,3)
338 1 1 1 (1,1,1)",
339 )
340 .split_column_at(3);
341
342 // test eval
343 let output = expr.eval(&input).await.unwrap();
344 assert_eq!(&output, expected.column_at(0));
345
346 // test eval_row
347 for (row, expected) in input.rows().zip_eq_debug(expected.rows()) {
348 let result = expr.eval_row(&row.to_owned_row()).await.unwrap();
349 assert_eq!(result, expected.datum_at(0).to_owned_datum());
350 }
351 }
352}