risingwave_expr/expr/
expr_input_ref.rs
1use std::ops::Index;
16
17use risingwave_common::array::{ArrayRef, DataChunk};
18use risingwave_common::row::OwnedRow;
19use risingwave_common::types::{DataType, Datum};
20use risingwave_pb::expr::ExprNode;
21
22use super::{BoxedExpression, Build};
23use crate::Result;
24use crate::expr::Expression;
25
26#[derive(Debug, Clone)]
28pub struct InputRefExpression {
29 return_type: DataType,
30 idx: usize,
31}
32
33#[async_trait::async_trait]
34impl Expression for InputRefExpression {
35 fn return_type(&self) -> DataType {
36 self.return_type.clone()
37 }
38
39 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
40 Ok(input.column_at(self.idx).clone())
41 }
42
43 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
44 let cell = input.index(self.idx).as_ref().cloned();
45 Ok(cell)
46 }
47
48 fn input_ref_index(&self) -> Option<usize> {
49 Some(self.idx)
50 }
51}
52
53impl InputRefExpression {
54 pub fn new(return_type: DataType, idx: usize) -> Self {
55 InputRefExpression { return_type, idx }
56 }
57
58 pub fn from_prost(prost: &ExprNode) -> Self {
62 let ret_type = DataType::from(prost.get_return_type().unwrap());
63 let input_col_idx = prost.get_rex_node().unwrap().as_input_ref().unwrap();
64
65 Self {
66 return_type: ret_type,
67 idx: *input_col_idx as _,
68 }
69 }
70
71 pub fn index(&self) -> usize {
72 self.idx
73 }
74
75 pub fn eval_immut(&self, input: &DataChunk) -> Result<ArrayRef> {
76 Ok(input.column_at(self.idx).clone())
77 }
78}
79
80impl Build for InputRefExpression {
81 fn build(
82 prost: &ExprNode,
83 _build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
84 ) -> Result<Self> {
85 Ok(Self::from_prost(prost))
86 }
87}
88
#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, Datum};

    use crate::expr::{Expression, InputRefExpression};

    /// An input reference to column `i` returns exactly the `i`-th datum of the
    /// row, including a NULL (`None`) one.
    #[tokio::test]
    async fn test_eval_row_input_ref() {
        let expected_datums: Vec<Datum> = vec![Some(1.into()), Some(2.into()), None];
        let row = OwnedRow::new(expected_datums.clone());

        for (col, want) in expected_datums.into_iter().enumerate() {
            let got = InputRefExpression::new(DataType::Int32, col)
                .eval_row(&row)
                .await
                .unwrap();
            assert_eq!(want, got);
        }
    }
}