risingwave_expr_impl/scalar/
in_.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::Debug;
17use std::sync::Arc;
18
19use futures_util::FutureExt;
20use risingwave_common::array::{ArrayBuilder, ArrayRef, BoolArrayBuilder, DataChunk};
21use risingwave_common::row::OwnedRow;
22use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum};
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_expr::expr::{BoxedExpression, Expression};
25use risingwave_expr::{Result, build_function};
26
27#[derive(Debug)]
28pub struct InExpression {
29    left: BoxedExpression,
30    set: HashSet<Datum>,
31    return_type: DataType,
32}
33
34impl InExpression {
35    pub fn new(
36        left: BoxedExpression,
37        data: impl Iterator<Item = Datum>,
38        return_type: DataType,
39    ) -> Self {
40        Self {
41            left,
42            set: data.collect(),
43            return_type,
44        }
45    }
46
47    // Returns true if datum exists in set, null if datum is null or datum does not exist in set
48    // but null does, and false if neither datum nor null exists in set.
49    fn exists(&self, datum: &Datum) -> Option<bool> {
50        if datum.is_none() {
51            None
52        } else if self.set.contains(datum) {
53            Some(true)
54        } else if self.set.contains(&None) {
55            None
56        } else {
57            Some(false)
58        }
59    }
60}
61
62#[async_trait::async_trait]
63impl Expression for InExpression {
64    fn return_type(&self) -> DataType {
65        self.return_type.clone()
66    }
67
68    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
69        let input_array = self.left.eval(input).await?;
70        let mut output_array = BoolArrayBuilder::new(input_array.len());
71        for (data, vis) in input_array.iter().zip_eq_fast(input.visibility().iter()) {
72            if vis {
73                // TODO: avoid `to_owned_datum()`
74                let ret = self.exists(&data.to_owned_datum());
75                output_array.append(ret);
76            } else {
77                output_array.append(None);
78            }
79        }
80        Ok(Arc::new(output_array.finish().into()))
81    }
82
83    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
84        let data = self.left.eval_row(input).await?;
85        let ret = self.exists(&data);
86        Ok(ret.map(|b| b.to_scalar_value()))
87    }
88}
89
90#[build_function("in(any, ...) -> boolean")]
91fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
92    let mut iter = children.into_iter();
93    let left_expr = iter.next().unwrap();
94    let mut data = Vec::with_capacity(iter.size_hint().0);
95    let data_chunk = DataChunk::new_dummy(1);
96    for child in iter {
97        let array = child
98            .eval(&data_chunk)
99            .now_or_never()
100            .expect("constant expression should not be async")?;
101        let datum = array.value_at(0).to_owned_datum();
102        data.push(datum);
103    }
104    Ok(Box::new(InExpression::new(
105        left_expr,
106        data.into_iter(),
107        return_type,
108    )))
109}
110
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::{Expression, build_from_pretty};

    /// Builds the expression described by `pretty_expr`, splits the chunk
    /// described by `pretty_chunk` into input (column 0) and expected output
    /// (column 1), and checks both the chunk and the per-row evaluation paths.
    async fn check_in_expr(pretty_expr: &str, pretty_chunk: &str) {
        let expr = build_from_pretty(pretty_expr);
        let (input, expected) = DataChunk::from_pretty(pretty_chunk).split_column_at(1);

        // test eval
        let actual_column = expr.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        // test eval_row
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let actual = expr.eval_row(&input_row.to_owned_row()).await.unwrap();
            assert_eq!(actual, expected_row.datum_at(0).to_owned_datum());
        }
    }

    /// Set without a null: misses are `false`, a null input stays null.
    #[tokio::test]
    async fn test_in_expr() {
        check_in_expr(
            "(in:boolean $0:varchar abc:varchar def:varchar)",
            "T   B
             abc t
             a   f
             def t
             abc t
             .   .",
        )
        .await;
    }

    /// Set containing a null: misses become null instead of `false`.
    #[tokio::test]
    async fn test_in_expr_null() {
        check_in_expr(
            "(in:boolean $0:varchar abc:varchar null:varchar)",
            "T   B
             abc t
             a   .
             .   .",
        )
        .await;
    }
}