risingwave_expr_impl/scalar/
field.rs1use anyhow::anyhow;
16use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk};
17use risingwave_common::row::OwnedRow;
18use risingwave_common::types::{DataType, Datum, ScalarImpl};
19use risingwave_expr::expr::{BoxedExpression, Expression};
20use risingwave_expr::{Result, build_function};
21
22#[derive(Debug)]
24pub struct FieldExpression {
25 return_type: DataType,
26 input: BoxedExpression,
27 index: usize,
28}
29
30#[async_trait::async_trait]
31impl Expression for FieldExpression {
32 fn return_type(&self) -> DataType {
33 self.return_type.clone()
34 }
35
36 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
37 let array = self.input.eval(input).await?;
38 if let ArrayImpl::Struct(struct_array) = array.as_ref() {
39 Ok(struct_array.field_at(self.index).clone())
40 } else {
41 Err(anyhow!("expects a struct array ref").into())
42 }
43 }
44
45 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
46 let struct_datum = self.input.eval_row(input).await?;
47 struct_datum
48 .map(|s| match s {
49 ScalarImpl::Struct(v) => Ok(v.fields()[self.index].clone()),
50 _ => Err(anyhow!("expects a struct array ref").into()),
51 })
52 .transpose()
53 .map(|x| x.flatten())
54 }
55}
56
57#[build_function("field(struct, int4) -> any", type_infer = "unreachable")]
58fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
59 let [input, index]: [_; 2] = children.try_into().unwrap();
62 let index = index.eval_const()?.unwrap().into_int32() as usize;
63 Ok(Box::new(FieldExpression {
64 return_type,
65 input,
66 index,
67 }))
68}
69
#[cfg(test)]
mod tests {
    use risingwave_common::array::{DataChunk, DataChunkTestExt};
    use risingwave_common::row::Row;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Extracts field 0 (`int4`) from a `struct<a_int4, b_float4>` column and checks both the
    /// vectorized and the row-at-a-time evaluation paths against the expected column.
    #[tokio::test]
    async fn test_field_expr() {
        let field_expr = build_from_pretty("(field:int4 $0:struct<a_int4,b_float4> 0:int4)");
        let chunk = DataChunk::from_pretty(
            "<i,f> i
             (1,2.0) 1
             (2,2.0) 2
             (3,2.0) 3",
        );
        // Column 0 is the struct input; column 1 holds the expected field values.
        let (input, expected) = chunk.split_column_at(1);

        // Vectorized path.
        let actual_column = field_expr.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        // Row-based path must agree row by row.
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let actual = field_expr
                .eval_row(&input_row.to_owned_row())
                .await
                .unwrap();
            assert_eq!(actual, expected_row.datum_at(0).to_owned_datum());
        }
    }
}