risingwave_expr/expr/
expr_input_ref.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Index;
16
17use risingwave_common::array::{ArrayRef, DataChunk};
18use risingwave_common::row::OwnedRow;
19use risingwave_common::types::{DataType, Datum};
20use risingwave_pb::expr::ExprNode;
21
22use super::{BoxedExpression, Build};
23use crate::Result;
24use crate::expr::Expression;
25
26/// A reference to a column in input relation.
27#[derive(Debug, Clone)]
28pub struct InputRefExpression {
29    return_type: DataType,
30    idx: usize,
31}
32
33#[async_trait::async_trait]
34impl Expression for InputRefExpression {
35    fn return_type(&self) -> DataType {
36        self.return_type.clone()
37    }
38
39    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
40        Ok(input.column_at(self.idx).clone())
41    }
42
43    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
44        let cell = input.index(self.idx).as_ref().cloned();
45        Ok(cell)
46    }
47
48    fn input_ref_index(&self) -> Option<usize> {
49        Some(self.idx)
50    }
51}
52
53impl InputRefExpression {
54    pub fn new(return_type: DataType, idx: usize) -> Self {
55        InputRefExpression { return_type, idx }
56    }
57
58    /// Create an [`InputRefExpression`] from a protobuf expression.
59    ///
60    /// Panics if the protobuf expression is not an input reference.
61    pub fn from_prost(prost: &ExprNode) -> Self {
62        let ret_type = DataType::from(prost.get_return_type().unwrap());
63        let input_col_idx = prost.get_rex_node().unwrap().as_input_ref().unwrap();
64
65        Self {
66            return_type: ret_type,
67            idx: *input_col_idx as _,
68        }
69    }
70
71    pub fn index(&self) -> usize {
72        self.idx
73    }
74
75    pub fn eval_immut(&self, input: &DataChunk) -> Result<ArrayRef> {
76        Ok(input.column_at(self.idx).clone())
77    }
78}
79
80impl Build for InputRefExpression {
81    fn build(
82        prost: &ExprNode,
83        _build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
84    ) -> Result<Self> {
85        Ok(Self::from_prost(prost))
86    }
87}
88
#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, Datum};

    use crate::expr::{Expression, InputRefExpression};

    /// An input ref evaluated against a row must yield exactly the datum stored at its index,
    /// including the `None` cell.
    #[tokio::test]
    async fn test_eval_row_input_ref() {
        let cells: Vec<Datum> = vec![Some(1.into()), Some(2.into()), None];
        let row = OwnedRow::new(cells.clone());

        for (col_idx, want) in cells.into_iter().enumerate() {
            let input_ref = InputRefExpression::new(DataType::Int32, col_idx);
            let got = input_ref.eval_row(&row).await.unwrap();
            assert_eq!(want, got);
        }
    }
}