risingwave_expr_impl/scalar/array.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::StructValue;
16use risingwave_common::row::Row;
17use risingwave_common::types::{
18 DataType, ListRef, MapRef, MapType, MapValue, ScalarRef, ScalarRefImpl, ToOwnedDatum,
19};
20use risingwave_expr::{ExprError, function};
21
22use super::array_positions::array_position;
23
24#[function("array(...) -> anyarray", type_infer = "unreachable")]
25fn array(row: impl Row, writer: &mut impl risingwave_common::array::ListWrite) {
26 writer.write_iter(row.iter());
27}
28
29#[function("row(...) -> struct", type_infer = "unreachable")]
30fn row_(row: impl Row) -> StructValue {
31 StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
32}
33
34fn map_from_key_values_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
35 let map = MapType::try_from_kv(
36 args[0].as_list_elem().clone(),
37 args[1].as_list_elem().clone(),
38 )
39 .map_err(ExprError::Custom)?;
40 Ok(map.into())
41}
42
43fn map_from_entries_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
44 let map =
45 MapType::try_from_entries(args[0].as_list_elem().clone()).map_err(ExprError::Custom)?;
46 Ok(map.into())
47}
48
49/// # Example
50///
51/// ```slt
52/// query T
53/// select map_from_key_values(null::int[], array[1,2,3]);
54/// ----
55/// NULL
56///
57/// query T
58/// select map_from_key_values(array['a','b','c'], array[1,2,3]);
59/// ----
60/// {a:1,b:2,c:3}
61/// ```
62#[function(
63 "map_from_key_values(anyarray, anyarray) -> anymap",
64 type_infer = "map_from_key_values_type_infer"
65)]
66fn map_from_key_values(keys: ListRef<'_>, values: ListRef<'_>) -> Result<MapValue, ExprError> {
67 MapValue::try_from_kv(keys.to_owned_scalar(), values.to_owned_scalar())
68 .map_err(ExprError::Custom)
69}
70
71#[function(
72 "map_from_entries(anyarray) -> anymap",
73 type_infer = "map_from_entries_type_infer"
74)]
75fn map_from_entries(entries: ListRef<'_>) -> Result<MapValue, ExprError> {
76 MapValue::try_from_entries(entries.to_owned_scalar()).map_err(ExprError::Custom)
77}
78
79/// # Example
80///
81/// ```slt
82/// query T
83/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 3);
84/// ----
85/// 300
86///
87/// query T
88/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), '3');
89/// ----
90/// 300
91///
92/// query error
93/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 1.0);
94/// ----
95/// db error: ERROR: Failed to run the query
96///
97/// Caused by these errors (recent errors listed first):
98/// 1: Failed to bind expression: map_access(map_from_key_values(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
99/// 2: Bind error: Cannot access numeric in map(integer,integer)
100///
101///
102/// query T
103/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'a');
104/// ----
105/// 1
106///
107/// query T
108/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'd');
109/// ----
110/// NULL
111///
112/// query T
113/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), null);
114/// ----
115/// NULL
116/// ```
117#[function("map_access(anymap, any) -> any", type_infer = "unreachable")]
118fn map_access<'a>(
119 map: MapRef<'a>,
120 key: ScalarRefImpl<'_>,
121) -> Result<Option<ScalarRefImpl<'a>>, ExprError> {
122 // FIXME: DatumRef in return value is not support by the macro yet.
123
124 let (keys, values) = map.into_kv();
125 let idx = array_position(keys, Some(key))?;
126 match idx {
127 Some(idx) => Ok(values.get((idx - 1) as usize).unwrap()),
128 None => Ok(None),
129 }
130}
131
132/// ```slt
133/// query T
134/// select
135/// map_contains(MAP{1:1}, 1),
136/// map_contains(MAP{1:1}, 2),
137/// map_contains(MAP{1:1}, NULL),
138/// map_contains(MAP{1:1}, '1'),
139/// map_contains(MAP{'a':'1','b':'2'}, 'ab');
140/// ----
141/// t f NULL t f
142///
143///
144/// query error
145/// select map_contains(MAP{1:1}, 1.0);
146/// ----
147/// db error: ERROR: Failed to run the query
148///
149/// Caused by these errors (recent errors listed first):
150/// 1: Failed to bind expression: map_contains(MAP {1: 1}, 1.0)
151/// 2: Bind error: Cannot check if numeric exists in map(integer,integer)
152///
153///
154/// query error
155/// select map_contains(MAP{1:1}, NULL::varchar);
156/// ----
157/// db error: ERROR: Failed to run the query
158///
159/// Caused by these errors (recent errors listed first):
160/// 1: Failed to bind expression: map_contains(MAP {1: 1}, CAST(NULL AS CHARACTER VARYING))
161/// 2: Bind error: Cannot check if character varying exists in map(integer,integer)
162/// ```
163#[function("map_contains(anymap, any) -> boolean")]
164fn map_contains(map: MapRef<'_>, key: ScalarRefImpl<'_>) -> Result<bool, ExprError> {
165 let (keys, _values) = map.into_kv();
166 let idx = array_position(keys, Some(key))?;
167 Ok(idx.is_some())
168}
169
170/// ```slt
171/// query I
172/// select
173/// map_length(NULL::map(int,int)),
174/// map_length(MAP {}::map(int,int)),
175/// map_length(MAP {1:1,2:2}::map(int,int))
176/// ----
177/// NULL 0 2
178/// ```
179#[function("map_length(anymap) -> int4")]
180fn map_length<T: TryFrom<usize>>(map: MapRef<'_>) -> Result<T, ExprError> {
181 map.len().try_into().map_err(|_| ExprError::NumericOverflow)
182}
183
184/// If both `m1` and `m2` have a value with the same key, then the output map contains the value from `m2`.
185///
186/// ```slt
187/// query T
188/// select map_cat(MAP{'a':1,'b':2},null::map(varchar,int));
189/// ----
190/// {a:1,b:2}
191///
192/// query T
193/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3,'c':4});
194/// ----
195/// {a:1,b:3,c:4}
196///
197/// # implicit type cast
198/// query T
199/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3.0,'c':4.0});
200/// ----
201/// {a:1,b:3.0,c:4.0}
202/// ```
203#[function("map_cat(anymap, anymap) -> anymap")]
204fn map_cat(m1: Option<MapRef<'_>>, m2: Option<MapRef<'_>>) -> Option<MapValue> {
205 match (m1, m2) {
206 (None, None) => None,
207 (Some(m), None) | (None, Some(m)) => Some(m.to_owned_scalar()),
208 (Some(m1), Some(m2)) => Some(MapValue::concat(m1, m2)),
209 }
210}
211
212/// Inserts a key-value pair into the map. If the key already exists, the value is updated.
213///
214/// # Example
215///
216/// ```slt
217/// query T
218/// select map_insert(map{'a':1, 'b':2}, 'c', 3);
219/// ----
220/// {a:1,b:2,c:3}
221///
222/// query T
223/// select map_insert(map{'a':1, 'b':2}, 'b', 4);
224/// ----
225/// {a:1,b:4}
226/// ```
227///
228/// TODO: support variadic arguments
229#[function("map_insert(anymap, any, any) -> anymap")]
230fn map_insert(
231 map: MapRef<'_>,
232 key: Option<ScalarRefImpl<'_>>,
233 value: Option<ScalarRefImpl<'_>>,
234) -> MapValue {
235 let Some(key) = key else {
236 return map.to_owned_scalar();
237 };
238 MapValue::insert(map, key.into_scalar_impl(), value.to_owned_datum())
239}
240
241/// Deletes a key-value pair from the map.
242///
243/// # Example
244///
245/// ```slt
246/// query T
247/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'b');
248/// ----
249/// {a:1,c:3}
250///
251/// query T
252/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'd');
253/// ----
254/// {a:1,b:2,c:3}
255/// ```
256///
257/// TODO: support variadic arguments
258#[function("map_delete(anymap, any) -> anymap")]
259fn map_delete(map: MapRef<'_>, key: Option<ScalarRefImpl<'_>>) -> MapValue {
260 let Some(key) = key else {
261 return map.to_owned_scalar();
262 };
263 MapValue::delete(map, key)
264}
265
266/// # Example
267///
268/// ```slt
269/// query T
270/// select map_keys(map{'a':1, 'b':2, 'c':3});
271/// ----
272/// {a,b,c}
273/// ```
274#[function(
275 "map_keys(anymap) -> anyarray",
276 type_infer = "|args|{
277 Ok(DataType::list(args[0].as_map().key().clone()))
278 }"
279)]
280fn map_keys(map: MapRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
281 writer.write_iter(map.into_kv().0.iter());
282}
283
284/// # Example
285///
286/// ```slt
287/// query T
288/// select map_values(map{'a':1, 'b':2, 'c':3});
289/// ----
290/// {1,2,3}
291/// ```
292#[function(
293 "map_values(anymap) -> anyarray",
294 type_infer = "|args|{
295 Ok(DataType::list(args[0].as_map().value().clone()))
296 }"
297)]
298fn map_values(map: MapRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
299 writer.write_iter(map.into_kv().1.iter());
300}
301
302/// # Example
303///
304/// ```slt
305/// query T
306/// select map_entries(map{'a':1, 'b':2, 'c':3});
307/// ----
308/// {"(a,1)","(b,2)","(c,3)"}
309/// ```
310#[function(
311 "map_entries(anymap) -> anyarray",
312 type_infer = "|args|{
313 Ok(args[0].as_map().clone().into_list())
314 }"
315)]
316fn map_entries(map: MapRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
317 writer.write_iter(map.into_inner().iter());
318}
319
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Evaluates the `row` expression over three `int4` columns and compares the result with
    /// the expected struct column, both chunk-at-a-time and row-at-a-time.
    #[tokio::test]
    async fn test_row_expr() {
        let row_expr =
            build_from_pretty("(row:struct<a_int4,b_int4,c_int4> $0:int4 $1:int4 $2:int4)");
        // The first three columns are the inputs; the fourth is the expected output.
        let chunk = DataChunk::from_pretty(
            "i i i <i,i,i>
             1 2 3 (1,2,3)
             4 2 1 (4,2,1)
             9 1 3 (9,1,3)
             1 1 1 (1,1,1)",
        );
        let (input, expected) = chunk.split_column_at(3);

        // test eval
        let evaluated = row_expr.eval(&input).await.unwrap();
        assert_eq!(&evaluated, expected.column_at(0));

        // test eval_row
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let owned_row = input_row.to_owned_row();
            let datum = row_expr.eval_row(&owned_row).await.unwrap();
            assert_eq!(datum, expected_row.datum_at(0).to_owned_datum());
        }
    }
}