risingwave_expr_impl/scalar/
field.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::anyhow;
16use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk};
17use risingwave_common::row::OwnedRow;
18use risingwave_common::types::{DataType, Datum, ScalarImpl};
19use risingwave_expr::expr::{BoxedExpression, Expression};
20use risingwave_expr::{Result, build_function};
21
22/// `FieldExpression` access a field from a struct.
23#[derive(Debug)]
24pub struct FieldExpression {
25    return_type: DataType,
26    input: BoxedExpression,
27    index: usize,
28}
29
30#[async_trait::async_trait]
31impl Expression for FieldExpression {
32    fn return_type(&self) -> DataType {
33        self.return_type.clone()
34    }
35
36    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
37        let array = self.input.eval(input).await?;
38        if let ArrayImpl::Struct(struct_array) = array.as_ref() {
39            Ok(struct_array.field_at(self.index).clone())
40        } else {
41            Err(anyhow!("expects a struct array ref").into())
42        }
43    }
44
45    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
46        let struct_datum = self.input.eval_row(input).await?;
47        struct_datum
48            .map(|s| match s {
49                ScalarImpl::Struct(v) => Ok(v.fields()[self.index].clone()),
50                _ => Err(anyhow!("expects a struct array ref").into()),
51            })
52            .transpose()
53            .map(|x| x.flatten())
54    }
55}
56
57#[build_function("field(struct, int4) -> any", type_infer = "unreachable")]
58fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
59    // Field `func_call_node` have 2 child nodes, the first is Field `FuncCall` or
60    // `InputRef`, the second is i32 `Literal`.
61    let [input, index]: [_; 2] = children.try_into().unwrap();
62    let index = index.eval_const()?.unwrap().into_int32() as usize;
63    Ok(Box::new(FieldExpression {
64        return_type,
65        input,
66        index,
67    }))
68}
69
#[cfg(test)]
mod tests {
    use risingwave_common::array::{DataChunk, DataChunkTestExt};
    use risingwave_common::row::Row;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Extracting field 0 of `struct<a int4, b float4>` must yield the `a` values,
    /// both through vectorized `eval` and through row-wise `eval_row`.
    #[tokio::test]
    async fn test_field_expr() {
        let expr = build_from_pretty("(field:int4 $0:struct<a_int4,b_float4> 0:int4)");
        // Column 0 is the struct input; column 1 is the expected field value.
        let chunk = DataChunk::from_pretty(
            "<i,f>   i
             (1,2.0) 1
             (2,2.0) 2
             (3,2.0) 3",
        );
        let (input, expected) = chunk.split_column_at(1);

        // Vectorized path: the produced column must match the expected column exactly.
        let actual_column = expr.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        // Row-wise path: every row must agree with its expected datum.
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let actual_datum = expr.eval_row(&input_row.to_owned_row()).await.unwrap();
            let expected_datum = expected_row.datum_at(0).to_owned_datum();
            assert_eq!(actual_datum, expected_datum);
        }
    }
}