risingwave_expr_impl/scalar/
array_transform.rs1use std::sync::Arc;
16
17use async_trait::async_trait;
18use risingwave_common::array::{ArrayRef, DataChunk};
19use risingwave_common::row::OwnedRow;
20use risingwave_common::types::{DataType, Datum, ListValue, ScalarImpl};
21use risingwave_expr::expr::{BoxedExpression, Expression};
22use risingwave_expr::{Result, build_function};
23
24#[derive(Debug)]
25struct ArrayTransformExpression {
26 array: BoxedExpression,
27 lambda: BoxedExpression,
28}
29
30#[async_trait]
31impl Expression for ArrayTransformExpression {
32 fn return_type(&self) -> DataType {
33 DataType::List(Box::new(self.lambda.return_type()))
34 }
35
36 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
37 let lambda_input = self.array.eval(input).await?;
38 let lambda_input = Arc::unwrap_or_clone(lambda_input).into_list();
39 let new_list = lambda_input
40 .map_inner(|flatten_input| async move {
41 let flatten_len = flatten_input.len();
42 let chunk = DataChunk::new(vec![Arc::new(flatten_input)], flatten_len);
43 self.lambda.eval(&chunk).await.map(Arc::unwrap_or_clone)
44 })
45 .await?;
46 Ok(Arc::new(new_list.into()))
47 }
48
49 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
50 let lambda_input = self.array.eval_row(input).await?;
51 let lambda_input = lambda_input.map(ScalarImpl::into_list);
52 if let Some(lambda_input) = lambda_input {
53 let len = lambda_input.len();
54 let chunk = DataChunk::new(vec![Arc::new(lambda_input.into_array())], len);
55 let new_vals = self.lambda.eval(&chunk).await?;
56 let new_list = ListValue::new(Arc::unwrap_or_clone(new_vals));
57 Ok(Some(new_list.into()))
58 } else {
59 Ok(None)
60 }
61 }
62}
63
64#[build_function("array_transform(anyarray, any) -> anyarray")]
65fn build(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
66 let [array, lambda] = <[BoxedExpression; 2]>::try_from(children).unwrap();
67 Ok(Box::new(ArrayTransformExpression { array, lambda }))
68}
69
#[cfg(test)]
mod tests {
    use risingwave_common::array::{DataChunk, DataChunkTestExt};
    use risingwave_common::row::Row;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Doubling each element of `{1,2,3}` must give `{2,4,6}` on both the vectorized (`eval`)
    /// and the row-at-a-time (`eval_row`) paths.
    #[tokio::test]
    async fn test_array_transform() {
        // Inside the lambda, `$0` is the current element, not the outer array column.
        let doubled =
            build_from_pretty("(array_transform:int4[] $0:int4[] (multiply:int4 $0:int4 2:int4))");
        let chunk = DataChunk::from_pretty(
            "i[] i[]
             {1,2,3} {2,4,6}",
        );
        // Column 0 is the input array; column 1 is the expected result.
        let (input, expected) = chunk.split_column_at(1);

        let actual_column = doubled.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        let row_pairs = input.rows().zip_eq_debug(expected.rows());
        for (input_row, expected_row) in row_pairs {
            let actual_datum = doubled.eval_row(&input_row.to_owned_row()).await.unwrap();
            assert_eq!(actual_datum, expected_row.datum_at(0).to_owned_datum());
        }
    }
}