risingwave_expr_impl/scalar/array_transform.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use risingwave_common::array::{ArrayRef, DataChunk};
19use risingwave_common::row::OwnedRow;
20use risingwave_common::types::{DataType, Datum, ListValue, ScalarImpl};
21use risingwave_expr::expr::{BoxedExpression, Expression};
22use risingwave_expr::{Result, build_function};
23
24#[derive(Debug)]
25struct ArrayTransformExpression {
26    array: BoxedExpression,
27    lambda: BoxedExpression,
28}
29
30#[async_trait]
31impl Expression for ArrayTransformExpression {
32    fn return_type(&self) -> DataType {
33        DataType::List(Box::new(self.lambda.return_type()))
34    }
35
36    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
37        let lambda_input = self.array.eval(input).await?;
38        let lambda_input = Arc::unwrap_or_clone(lambda_input).into_list();
39        let new_list = lambda_input
40            .map_inner(|flatten_input| async move {
41                let flatten_len = flatten_input.len();
42                let chunk = DataChunk::new(vec![Arc::new(flatten_input)], flatten_len);
43                self.lambda.eval(&chunk).await.map(Arc::unwrap_or_clone)
44            })
45            .await?;
46        Ok(Arc::new(new_list.into()))
47    }
48
49    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
50        let lambda_input = self.array.eval_row(input).await?;
51        let lambda_input = lambda_input.map(ScalarImpl::into_list);
52        if let Some(lambda_input) = lambda_input {
53            let len = lambda_input.len();
54            let chunk = DataChunk::new(vec![Arc::new(lambda_input.into_array())], len);
55            let new_vals = self.lambda.eval(&chunk).await?;
56            let new_list = ListValue::new(Arc::unwrap_or_clone(new_vals));
57            Ok(Some(new_list.into()))
58        } else {
59            Ok(None)
60        }
61    }
62}
63
64#[build_function("array_transform(anyarray, any) -> anyarray")]
65fn build(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
66    let [array, lambda] = <[BoxedExpression; 2]>::try_from(children).unwrap();
67    Ok(Box::new(ArrayTransformExpression { array, lambda }))
68}
69
#[cfg(test)]
mod tests {
    use risingwave_common::array::{DataChunk, DataChunkTestExt};
    use risingwave_common::row::Row;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Doubles every element of an `int4[]` through the lambda. Checks that the vectorized
    /// (`eval`) and row-at-a-time (`eval_row`) paths both produce the expected lists.
    #[tokio::test]
    async fn test_array_transform() {
        let expr =
            build_from_pretty("(array_transform:int4[] $0:int4[] (multiply:int4 $0:int4 2:int4))");
        let chunk = DataChunk::from_pretty(
            "i[]     i[]
             {1,2,3} {2,4,6}",
        );
        // Column 0 holds the inputs; column 1 holds the expected outputs.
        let (input, expected) = chunk.split_column_at(1);

        // Vectorized evaluation over the whole chunk.
        let actual_column = expr.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        // Row-at-a-time evaluation must agree with the expectation for every row.
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let actual_datum = expr.eval_row(&input_row.to_owned_row()).await.unwrap();
            assert_eq!(actual_datum, expected_row.datum_at(0).to_owned_datum());
        }
    }
}