risingwave_expr_impl/scalar/
array.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use risingwave_common::array::StructValue;
16use risingwave_common::row::Row;
17use risingwave_common::types::{
18    DataType, ListRef, MapRef, MapType, MapValue, ScalarRef, ScalarRefImpl, ToOwnedDatum,
19};
20use risingwave_expr::{ExprError, function};
21
22use super::array_positions::array_position;
23
24#[function("array(...) -> anyarray", type_infer = "unreachable")]
25fn array(row: impl Row, writer: &mut impl risingwave_common::array::ListWrite) {
26    writer.write_iter(row.iter());
27}
28
29#[function("row(...) -> struct", type_infer = "unreachable")]
30fn row_(row: impl Row) -> StructValue {
31    StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
32}
33
34fn map_from_key_values_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
35    let map = MapType::try_from_kv(
36        args[0].as_list_elem().clone(),
37        args[1].as_list_elem().clone(),
38    )
39    .map_err(ExprError::Custom)?;
40    Ok(map.into())
41}
42
43fn map_from_entries_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
44    let map =
45        MapType::try_from_entries(args[0].as_list_elem().clone()).map_err(ExprError::Custom)?;
46    Ok(map.into())
47}
48
49/// # Example
50///
51/// ```slt
52/// query T
53/// select map_from_key_values(null::int[], array[1,2,3]);
54/// ----
55/// NULL
56///
57/// query T
58/// select map_from_key_values(array['a','b','c'], array[1,2,3]);
59/// ----
60/// {a:1,b:2,c:3}
61/// ```
62#[function(
63    "map_from_key_values(anyarray, anyarray) -> anymap",
64    type_infer = "map_from_key_values_type_infer"
65)]
66fn map_from_key_values(keys: ListRef<'_>, values: ListRef<'_>) -> Result<MapValue, ExprError> {
67    MapValue::try_from_kv(keys.to_owned_scalar(), values.to_owned_scalar())
68        .map_err(ExprError::Custom)
69}
70
71#[function(
72    "map_from_entries(anyarray) -> anymap",
73    type_infer = "map_from_entries_type_infer"
74)]
75fn map_from_entries(entries: ListRef<'_>) -> Result<MapValue, ExprError> {
76    MapValue::try_from_entries(entries.to_owned_scalar()).map_err(ExprError::Custom)
77}
78
79/// # Example
80///
81/// ```slt
82/// query T
83/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 3);
84/// ----
85/// 300
86///
87/// query T
88/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), '3');
89/// ----
90/// 300
91///
92/// query error
93/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 1.0);
94/// ----
95/// db error: ERROR: Failed to run the query
96///
97/// Caused by these errors (recent errors listed first):
98///   1: Failed to bind expression: map_access(map_from_key_values(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
99///   2: Bind error: Cannot access numeric in map(integer,integer)
100///
101///
102/// query T
103/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'a');
104/// ----
105/// 1
106///
107/// query T
108/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'd');
109/// ----
110/// NULL
111///
112/// query T
113/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), null);
114/// ----
115/// NULL
116/// ```
117#[function("map_access(anymap, any) -> any", type_infer = "unreachable")]
118fn map_access<'a>(
119    map: MapRef<'a>,
120    key: ScalarRefImpl<'_>,
121) -> Result<Option<ScalarRefImpl<'a>>, ExprError> {
122    // FIXME: DatumRef in return value is not support by the macro yet.
123
124    let (keys, values) = map.into_kv();
125    let idx = array_position(keys, Some(key))?;
126    match idx {
127        Some(idx) => Ok(values.get((idx - 1) as usize).unwrap()),
128        None => Ok(None),
129    }
130}
131
132/// ```slt
133/// query T
134/// select
135///     map_contains(MAP{1:1}, 1),
136///     map_contains(MAP{1:1}, 2),
137///     map_contains(MAP{1:1}, NULL),
138///     map_contains(MAP{1:1}, '1'),
139///     map_contains(MAP{'a':'1','b':'2'}, 'ab');
140/// ----
141/// t f NULL t f
142///
143///
144/// query error
145/// select map_contains(MAP{1:1}, 1.0);
146/// ----
147/// db error: ERROR: Failed to run the query
148///
149/// Caused by these errors (recent errors listed first):
150///   1: Failed to bind expression: map_contains(MAP {1: 1}, 1.0)
151///   2: Bind error: Cannot check if numeric exists in map(integer,integer)
152///
153///
154/// query error
155/// select map_contains(MAP{1:1}, NULL::varchar);
156/// ----
157/// db error: ERROR: Failed to run the query
158///
159/// Caused by these errors (recent errors listed first):
160///   1: Failed to bind expression: map_contains(MAP {1: 1}, CAST(NULL AS CHARACTER VARYING))
161///   2: Bind error: Cannot check if character varying exists in map(integer,integer)
162/// ```
163#[function("map_contains(anymap, any) -> boolean")]
164fn map_contains(map: MapRef<'_>, key: ScalarRefImpl<'_>) -> Result<bool, ExprError> {
165    let (keys, _values) = map.into_kv();
166    let idx = array_position(keys, Some(key))?;
167    Ok(idx.is_some())
168}
169
170/// ```slt
171/// query I
172/// select
173///     map_length(NULL::map(int,int)),
174///     map_length(MAP {}::map(int,int)),
175///     map_length(MAP {1:1,2:2}::map(int,int))
176/// ----
177/// NULL 0 2
178/// ```
179#[function("map_length(anymap) -> int4")]
180fn map_length<T: TryFrom<usize>>(map: MapRef<'_>) -> Result<T, ExprError> {
181    map.len().try_into().map_err(|_| ExprError::NumericOverflow)
182}
183
184/// If both `m1` and `m2` have a value with the same key, then the output map contains the value from `m2`.
185///
186/// ```slt
187/// query T
188/// select map_cat(MAP{'a':1,'b':2},null::map(varchar,int));
189/// ----
190/// {a:1,b:2}
191///
192/// query T
193/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3,'c':4});
194/// ----
195/// {a:1,b:3,c:4}
196///
197/// # implicit type cast
198/// query T
199/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3.0,'c':4.0});
200/// ----
201/// {a:1,b:3.0,c:4.0}
202/// ```
203#[function("map_cat(anymap, anymap) -> anymap")]
204fn map_cat(m1: Option<MapRef<'_>>, m2: Option<MapRef<'_>>) -> Option<MapValue> {
205    match (m1, m2) {
206        (None, None) => None,
207        (Some(m), None) | (None, Some(m)) => Some(m.to_owned_scalar()),
208        (Some(m1), Some(m2)) => Some(MapValue::concat(m1, m2)),
209    }
210}
211
212/// Inserts a key-value pair into the map. If the key already exists, the value is updated.
213///
214/// # Example
215///
216/// ```slt
217/// query T
218/// select map_insert(map{'a':1, 'b':2}, 'c', 3);
219/// ----
220/// {a:1,b:2,c:3}
221///
222/// query T
223/// select map_insert(map{'a':1, 'b':2}, 'b', 4);
224/// ----
225/// {a:1,b:4}
226/// ```
227///
228/// TODO: support variadic arguments
229#[function("map_insert(anymap, any, any) -> anymap")]
230fn map_insert(
231    map: MapRef<'_>,
232    key: Option<ScalarRefImpl<'_>>,
233    value: Option<ScalarRefImpl<'_>>,
234) -> MapValue {
235    let Some(key) = key else {
236        return map.to_owned_scalar();
237    };
238    MapValue::insert(map, key.into_scalar_impl(), value.to_owned_datum())
239}
240
241/// Deletes a key-value pair from the map.
242///
243/// # Example
244///
245/// ```slt
246/// query T
247/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'b');
248/// ----
249/// {a:1,c:3}
250///
251/// query T
252/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'd');
253/// ----
254/// {a:1,b:2,c:3}
255/// ```
256///
257/// TODO: support variadic arguments
258#[function("map_delete(anymap, any) -> anymap")]
259fn map_delete(map: MapRef<'_>, key: Option<ScalarRefImpl<'_>>) -> MapValue {
260    let Some(key) = key else {
261        return map.to_owned_scalar();
262    };
263    MapValue::delete(map, key)
264}
265
266/// # Example
267///
268/// ```slt
269/// query T
270/// select map_keys(map{'a':1, 'b':2, 'c':3});
271/// ----
272/// {a,b,c}
273/// ```
274#[function(
275    "map_keys(anymap) -> anyarray",
276    type_infer = "|args|{
277        Ok(DataType::list(args[0].as_map().key().clone()))
278    }"
279)]
280fn map_keys(map: MapRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
281    writer.write_iter(map.into_kv().0.iter());
282}
283
284/// # Example
285///
286/// ```slt
287/// query T
288/// select map_values(map{'a':1, 'b':2, 'c':3});
289/// ----
290/// {1,2,3}
291/// ```
292#[function(
293    "map_values(anymap) -> anyarray",
294    type_infer = "|args|{
295        Ok(DataType::list(args[0].as_map().value().clone()))
296    }"
297)]
298fn map_values(map: MapRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
299    writer.write_iter(map.into_kv().1.iter());
300}
301
302/// # Example
303///
304/// ```slt
305/// query T
306/// select map_entries(map{'a':1, 'b':2, 'c':3});
307/// ----
308/// {"(a,1)","(b,2)","(c,3)"}
309/// ```
310#[function(
311    "map_entries(anymap) -> anyarray",
312    type_infer = "|args|{
313        Ok(args[0].as_map().clone().into_list())
314    }"
315)]
316fn map_entries(map: MapRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
317    writer.write_iter(map.into_inner().iter());
318}
319
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Checks `row(...)` through both the vectorized and the row-at-a-time paths.
    #[tokio::test]
    async fn test_row_expr() {
        let expr = build_from_pretty("(row:struct<a_int4,b_int4,c_int4> $0:int4 $1:int4 $2:int4)");
        // The first three columns are inputs; the fourth is the expected struct.
        let chunk = DataChunk::from_pretty(
            "i i i <i,i,i>
             1 2 3 (1,2,3)
             4 2 1 (4,2,1)
             9 1 3 (9,1,3)
             1 1 1 (1,1,1)",
        );
        let (input, expected) = chunk.split_column_at(3);

        // Vectorized evaluation over the whole chunk.
        let actual_column = expr.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        // Row-by-row evaluation must agree with the expected column as well.
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let actual = expr.eval_row(&input_row.to_owned_row()).await.unwrap();
            assert_eq!(actual, expected_row.datum_at(0).to_owned_datum());
        }
    }
}