risingwave_expr_impl/scalar/
in_.rs1use std::collections::HashSet;
16use std::fmt::Debug;
17use std::sync::Arc;
18
19use futures_util::FutureExt;
20use risingwave_common::array::{ArrayBuilder, ArrayRef, BoolArrayBuilder, DataChunk};
21use risingwave_common::row::OwnedRow;
22use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum};
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_expr::expr::{BoxedExpression, Expression};
25use risingwave_expr::{Result, build_function};
26
27#[derive(Debug)]
28pub struct InExpression {
29 left: BoxedExpression,
30 set: HashSet<Datum>,
31 return_type: DataType,
32}
33
34impl InExpression {
35 pub fn new(
36 left: BoxedExpression,
37 data: impl Iterator<Item = Datum>,
38 return_type: DataType,
39 ) -> Self {
40 Self {
41 left,
42 set: data.collect(),
43 return_type,
44 }
45 }
46
47 fn exists(&self, datum: &Datum) -> Option<bool> {
50 if datum.is_none() {
51 None
52 } else if self.set.contains(datum) {
53 Some(true)
54 } else if self.set.contains(&None) {
55 None
56 } else {
57 Some(false)
58 }
59 }
60}
61
62#[async_trait::async_trait]
63impl Expression for InExpression {
64 fn return_type(&self) -> DataType {
65 self.return_type.clone()
66 }
67
68 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
69 let input_array = self.left.eval(input).await?;
70 let mut output_array = BoolArrayBuilder::new(input_array.len());
71 for (data, vis) in input_array.iter().zip_eq_fast(input.visibility().iter()) {
72 if vis {
73 let ret = self.exists(&data.to_owned_datum());
75 output_array.append(ret);
76 } else {
77 output_array.append(None);
78 }
79 }
80 Ok(Arc::new(output_array.finish().into()))
81 }
82
83 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
84 let data = self.left.eval_row(input).await?;
85 let ret = self.exists(&data);
86 Ok(ret.map(|b| b.to_scalar_value()))
87 }
88}
89
90#[build_function("in(any, ...) -> boolean")]
91fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
92 let mut iter = children.into_iter();
93 let left_expr = iter.next().unwrap();
94 let mut data = Vec::with_capacity(iter.size_hint().0);
95 let data_chunk = DataChunk::new_dummy(1);
96 for child in iter {
97 let array = child
98 .eval(&data_chunk)
99 .now_or_never()
100 .expect("constant expression should not be async")?;
101 let datum = array.value_at(0).to_owned_datum();
102 data.push(datum);
103 }
104 Ok(Box::new(InExpression::new(
105 left_expr,
106 data.into_iter(),
107 return_type,
108 )))
109}
110
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::{Expression, build_from_pretty};

    /// Builds the expression from `expr_text`, splits `table` into the input
    /// column (index 0) and the expected column (index 1), and checks both the
    /// chunk-level and the row-level evaluation paths against the expectation.
    async fn assert_in_result(expr_text: &str, table: &str) {
        let expr = build_from_pretty(expr_text);
        let (input, expected) = DataChunk::from_pretty(table).split_column_at(1);

        // Chunk-level evaluation.
        let output = expr.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // Row-level evaluation.
        for (row, want) in input.rows().zip_eq_debug(expected.rows()) {
            let got = expr.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(got, want.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_in_expr() {
        assert_in_result(
            "(in:boolean $0:varchar abc:varchar def:varchar)",
            "T B
             abc t
             a f
             def t
             abc t
             . .",
        )
        .await;
    }

    #[tokio::test]
    async fn test_in_expr_null() {
        assert_in_result(
            "(in:boolean $0:varchar abc:varchar null:varchar)",
            "T B
             abc t
             a .
             . .",
        )
        .await;
    }
}