1use std::sync::Arc;
16
17use risingwave_common::array::{Array, ArrayRef, BoolArray, DataChunk};
18use risingwave_common::row::OwnedRow;
19use risingwave_common::types::{DataType, Datum, Scalar, ScalarRefImpl};
20use risingwave_common::util::iter_util::ZipEqFast;
21use risingwave_common::{bail, ensure};
22use risingwave_pb::expr::expr_node::{RexNode, Type};
23use risingwave_pb::expr::{ExprNode, FunctionCall};
24
25use super::build::get_children_and_return_type;
26use super::{BoxedExpression, Build, Expression};
27use crate::Result;
28
/// Evaluates `left <cmp> SOME(right)` / `left <cmp> ALL(right)`, where `right`
/// produces a list and `<cmp>` is the per-element comparison held in `func`.
#[derive(Debug)]
pub struct SomeAllExpression {
    /// Produces the left-hand operand that is compared against every list element.
    left_expr: BoxedExpression,
    /// Produces the right-hand list operand.
    right_expr: BoxedExpression,
    /// `Type::Some` or `Type::All`; selects how per-element results are combined.
    expr_type: Type,
    /// Per-element comparison. It is evaluated on a chunk whose column 0 holds the
    /// (repeated) left value and whose column 1 holds the list elements.
    func: BoxedExpression,
}
36
37impl SomeAllExpression {
38 pub fn new(
39 left_expr: BoxedExpression,
40 right_expr: BoxedExpression,
41 expr_type: Type,
42 func: BoxedExpression,
43 ) -> Self {
44 SomeAllExpression {
45 left_expr,
46 right_expr,
47 expr_type,
48 func,
49 }
50 }
51
52 fn resolve_bools(&self, bools: impl Iterator<Item = Option<bool>>) -> Option<bool> {
55 match self.expr_type {
56 Type::Some => {
57 let mut any_none = false;
58 for b in bools {
59 match b {
60 Some(true) => return Some(true),
61 Some(false) => continue,
62 None => any_none = true,
63 }
64 }
65 if any_none { None } else { Some(false) }
66 }
67 Type::All => {
68 let mut all_true = true;
69 for b in bools {
70 if b == Some(false) {
71 return Some(false);
72 }
73 if b != Some(true) {
74 all_true = false;
75 }
76 }
77 if all_true { Some(true) } else { None }
78 }
79 _ => unreachable!(),
80 }
81 }
82}
83
#[async_trait::async_trait]
impl Expression for SomeAllExpression {
    /// `SOME`/`ALL` always yields a boolean (possibly NULL).
    fn return_type(&self) -> DataType {
        DataType::Boolean
    }

    /// Vectorized evaluation, in three steps:
    ///
    /// 1. Unfold: every visible row with a non-NULL right-hand list is expanded
    ///    into one entry per (flattened) list element, with the row's left value
    ///    repeated alongside each element.
    /// 2. Evaluate: `func` runs once over all unfolded entries.
    /// 3. Fold: the per-row element counts in `num_array` slice the batch result
    ///    back into rows, and `resolve_bools` combines each slice.
    ///
    /// Rows whose list is NULL, and invisible rows, yield NULL.
    async fn eval(&self, data_chunk: &DataChunk) -> Result<ArrayRef> {
        let arr_left = self.left_expr.eval(data_chunk).await?;
        let arr_right = self.right_expr.eval(data_chunk).await?;
        // One entry per input row. `Some(n)` means the row contributed `n`
        // unfolded entries; `None` means the row's result is NULL and `func`
        // is not consulted for it.
        let mut num_array = Vec::with_capacity(data_chunk.capacity());

        let arr_right_inner = arr_right.as_list();
        // `build` rejects a right operand whose return type is not a list.
        let DataType::List(datatype) = arr_right_inner.data_type() else {
            unreachable!()
        };
        // Capacity hint only: it is taken over the whole array, including
        // elements of invisible rows.
        let capacity = arr_right_inner.flatten().len();

        let mut unfolded_arr_left_builder = arr_left.create_builder(capacity);
        let mut unfolded_arr_right_builder = datatype.create_array_builder(capacity);

        // Appends one row's unfolded entries to both builders and records how
        // many entries the row contributed.
        // NOTE(review): this path uses `flatten()`, while `eval_row` uses
        // `into_array()`. Confirm the two agree when list elements are
        // themselves lists.
        let mut unfolded_left_right =
            |left: Option<ScalarRefImpl<'_>>,
             right: Option<ScalarRefImpl<'_>>,
             num_array: &mut Vec<Option<usize>>| {
                // A NULL list makes the row's result NULL; nothing is unfolded.
                if right.is_none() {
                    num_array.push(None);
                    return;
                }

                let array = right.unwrap().into_list();
                let flattened = array.flatten();
                let len = flattened.len();
                num_array.push(Some(len));
                unfolded_arr_left_builder.append_n(len, left);
                for item in flattened.iter() {
                    unfolded_arr_right_builder.append(item);
                }
            };

        if data_chunk.is_compacted() {
            for (left, right) in arr_left.iter().zip_eq_fast(arr_right.iter()) {
                unfolded_left_right(left, right, &mut num_array);
            }
        } else {
            for ((left, right), visible) in arr_left
                .iter()
                .zip_eq_fast(arr_right.iter())
                .zip_eq_fast(data_chunk.visibility().iter())
            {
                // Invisible rows still occupy an output slot; emit NULL for them.
                if !visible {
                    num_array.push(None);
                    continue;
                }
                unfolded_left_right(left, right, &mut num_array);
            }
        }

        assert_eq!(num_array.len(), data_chunk.capacity());

        let unfolded_arr_left = unfolded_arr_left_builder.finish();
        let unfolded_arr_right = unfolded_arr_right_builder.finish();

        assert_eq!(unfolded_arr_left.len(), unfolded_arr_right.len());
        // Every unfolded entry must be evaluated, so the full length is passed
        // as the chunk's row count. This presumably means all rows are visible
        // (TODO confirm `DataChunk::new` semantics).
        let unfolded_compact_len = unfolded_arr_left.len();

        let data_chunk = DataChunk::new(
            vec![unfolded_arr_left.into(), unfolded_arr_right.into()],
            unfolded_compact_len,
        );

        let func_results = self.func.eval(&data_chunk).await?;
        let bools = func_results.as_bool();
        // Start index in `bools` of the current row's slice. It advances only
        // for rows that contributed entries.
        let mut offset = 0;
        Ok(Arc::new(
            num_array
                .into_iter()
                .map(|num| match num {
                    Some(num) => {
                        let range = offset..offset + num;
                        offset += num;
                        self.resolve_bools(range.map(|i| bools.value_at(i)))
                    }
                    None => None,
                })
                .collect::<BoolArray>()
                .into(),
        ))
    }

    /// Row-at-a-time counterpart of `eval`.
    ///
    /// Returns NULL when the right-hand list is NULL. Otherwise it:
    /// 1. builds a chunk whose column 0 repeats the left datum once per list
    ///    element and whose column 1 holds the elements;
    /// 2. runs `func` on that chunk;
    /// 3. combines the results with `resolve_bools`.
    async fn eval_row(&self, row: &OwnedRow) -> Result<Datum> {
        let datum_left = self.left_expr.eval_row(row).await?;
        let datum_right = self.right_expr.eval_row(row).await?;
        let Some(array_right) = datum_right else {
            return Ok(None);
        };
        // NOTE(review): uses `into_array()` rather than the `flatten()` used
        // by `eval`. Confirm both paths agree for nested-list operands.
        let array_right = array_right.into_list().into_array();
        let len = array_right.len();

        // Column 0 of the evaluation chunk: the left datum repeated `len` times.
        let array_left = {
            let mut builder = self.left_expr.return_type().create_array_builder(len);
            builder.append_n(len, datum_left);
            builder.finish().into_ref()
        };

        let chunk = DataChunk::new(vec![array_left, Arc::new(array_right)], len);
        let bools = self.func.eval(&chunk).await?;

        Ok(self
            .resolve_bools(bools.as_bool().iter())
            .map(|b| b.to_scalar_value()))
    }
}
199
200impl Build for SomeAllExpression {
201 fn build(
202 prost: &ExprNode,
203 build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
204 ) -> Result<Self> {
205 let outer_expr_type = prost.get_function_type().unwrap();
206 let (outer_children, outer_return_type) = get_children_and_return_type(prost)?;
207 ensure!(matches!(outer_return_type, DataType::Boolean));
208
209 let mut inner_expr_type = outer_children[0].get_function_type().unwrap();
210 let (mut inner_children, mut inner_return_type) =
211 get_children_and_return_type(&outer_children[0])?;
212 let mut stack = vec![];
213 while inner_children.len() != 2 {
214 stack.push((inner_expr_type, inner_return_type));
215 inner_expr_type = inner_children[0].get_function_type().unwrap();
216 (inner_children, inner_return_type) = get_children_and_return_type(&inner_children[0])?;
217 }
218
219 let left_expr = build_child(&inner_children[0])?;
220 let right_expr = build_child(&inner_children[1])?;
221
222 let DataType::List(right_expr_return_type) = right_expr.return_type() else {
223 bail!("Expect Array Type");
224 };
225
226 let eval_func = {
227 let left_expr_input_ref = ExprNode {
228 function_type: Type::Unspecified as i32,
229 return_type: Some(left_expr.return_type().to_protobuf()),
230 rex_node: Some(RexNode::InputRef(0)),
231 };
232 let right_expr_input_ref = ExprNode {
233 function_type: Type::Unspecified as i32,
234 return_type: Some(right_expr_return_type.to_protobuf()),
235 rex_node: Some(RexNode::InputRef(1)),
236 };
237 let mut root_expr_node = ExprNode {
238 function_type: inner_expr_type as i32,
239 return_type: Some(inner_return_type.to_protobuf()),
240 rex_node: Some(RexNode::FuncCall(FunctionCall {
241 children: vec![left_expr_input_ref, right_expr_input_ref],
242 })),
243 };
244 while let Some((expr_type, return_type)) = stack.pop() {
245 root_expr_node = ExprNode {
246 function_type: expr_type as i32,
247 return_type: Some(return_type.to_protobuf()),
248 rex_node: Some(RexNode::FuncCall(FunctionCall {
249 children: vec![root_expr_node],
250 })),
251 }
252 }
253 build_child(&root_expr_node)?
254 };
255
256 Ok(SomeAllExpression::new(
257 left_expr,
258 right_expr,
259 outer_expr_type,
260 eval_func,
261 ))
262 }
263}
264
#[cfg(test)]
mod tests {
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Runs a `SomeAllExpression` of kind `expr_type` over `table`.
    ///
    /// The expression's `func` simply returns each list element, so the test
    /// exercises only the SOME/ALL combination logic. The first column of
    /// `table` is the `B[]` input and the second is the expected `B` result.
    /// Both the vectorized (`eval`) and row-wise (`eval_row`) paths are checked
    /// against it.
    async fn check(expr_type: Type, table: &str) {
        let expr = SomeAllExpression::new(
            build_from_pretty("0:int4"),
            build_from_pretty("$0:boolean"),
            expr_type,
            build_from_pretty("$1:boolean"),
        );
        let (input, expected) = DataChunk::from_pretty(table).split_column_at(1);

        let batch = expr.eval(&input).await.unwrap();
        assert_eq!(&batch, expected.column_at(0));

        for (row, want) in input.rows().zip_eq_debug(expected.rows()) {
            let got = expr.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(got, want.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_some() {
        check(
            Type::Some,
            "B[]      B
             .        .
             {}       f
             {NULL}   .
             {NULL,f} .
             {NULL,t} t
             {t,f}    t
             {f,t}    t",
        )
        .await;
    }

    #[tokio::test]
    async fn test_all() {
        check(
            Type::All,
            "B[]      B
             .        .
             {}       t
             {NULL}   .
             {NULL,t} .
             {NULL,f} f
             {f,f}    f
             {t}      t",
        )
        .await;
    }
}