risingwave_expr/expr/wrapper/strict.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use async_trait::async_trait;
16use risingwave_common::array::{ArrayRef, DataChunk};
17use risingwave_common::row::OwnedRow;
18use risingwave_common::types::{DataType, Datum};
19
20use crate::ExprError;
21use crate::error::Result;
22use crate::expr::{Expression, ValueImpl};
23
/// Wraps an [`Expression`] so that chunk evaluation (`eval` / `eval_v2`) never surfaces an
/// [`ExprError::Multiple`]: such a result is replaced by only the first of its errors.
///
/// Row evaluation, constant evaluation and the remaining trait methods are forwarded to the
/// wrapped expression as-is.
pub(crate) struct Strict<E> {
    /// The wrapped expression; every call is delegated to it.
    inner: E,
}
28
29impl<E> std::fmt::Debug for Strict<E>
30where
31    E: std::fmt::Debug,
32{
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("Strict")
35            .field("inner", &self.inner)
36            .finish()
37    }
38}
39
40impl<E> Strict<E>
41where
42    E: Expression,
43{
44    pub fn new(inner: E) -> Self {
45        Self { inner }
46    }
47}
48
49#[async_trait]
50impl<E> Expression for Strict<E>
51where
52    E: Expression,
53{
54    fn return_type(&self) -> DataType {
55        self.inner.return_type()
56    }
57
58    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
59        match self.inner.eval(input).await {
60            Err(ExprError::Multiple(_, errors)) => Err(errors.into_first()),
61            res => res,
62        }
63    }
64
65    async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
66        match self.inner.eval_v2(input).await {
67            Err(ExprError::Multiple(_, errors)) => Err(errors.into_first()),
68            res => res,
69        }
70    }
71
72    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
73        self.inner.eval_row(input).await
74    }
75
76    fn eval_const(&self) -> Result<Datum> {
77        self.inner.eval_const()
78    }
79
80    fn input_ref_index(&self) -> Option<usize> {
81        self.inner.input_ref_index()
82    }
83}