1use std::sync::Arc;
16
17use risingwave_common::array::{Array, ArrayRef, BoolArray, DataChunk};
18use risingwave_common::row::OwnedRow;
19use risingwave_common::types::{DataType, Datum, Scalar, ScalarRefImpl};
20use risingwave_common::util::iter_util::ZipEqFast;
21use risingwave_common::{bail, ensure};
22use risingwave_pb::expr::expr_node::{RexNode, Type};
23use risingwave_pb::expr::{ExprNode, FunctionCall};
24
25use super::build::get_children_and_return_type;
26use super::{BoxedExpression, Build, Expression};
27use crate::Result;
28
29#[derive(Debug)]
30pub struct SomeAllExpression {
31 left_expr: BoxedExpression,
32 right_expr: BoxedExpression,
33 expr_type: Type,
34 func: BoxedExpression,
35}
36
37impl SomeAllExpression {
38 pub fn new(
39 left_expr: BoxedExpression,
40 right_expr: BoxedExpression,
41 expr_type: Type,
42 func: BoxedExpression,
43 ) -> Self {
44 SomeAllExpression {
45 left_expr,
46 right_expr,
47 expr_type,
48 func,
49 }
50 }
51
52 fn resolve_bools(&self, bools: impl Iterator<Item = Option<bool>>) -> Option<bool> {
55 match self.expr_type {
56 Type::Some => {
57 let mut any_none = false;
58 for b in bools {
59 match b {
60 Some(true) => return Some(true),
61 Some(false) => continue,
62 None => any_none = true,
63 }
64 }
65 if any_none { None } else { Some(false) }
66 }
67 Type::All => {
68 let mut all_true = true;
69 for b in bools {
70 if b == Some(false) {
71 return Some(false);
72 }
73 if b != Some(true) {
74 all_true = false;
75 }
76 }
77 if all_true { Some(true) } else { None }
78 }
79 _ => unreachable!(),
80 }
81 }
82}
83
84#[async_trait::async_trait]
85impl Expression for SomeAllExpression {
86 fn return_type(&self) -> DataType {
87 DataType::Boolean
88 }
89
90 async fn eval(&self, data_chunk: &DataChunk) -> Result<ArrayRef> {
91 let arr_left = self.left_expr.eval(data_chunk).await?;
92 let arr_right = self.right_expr.eval(data_chunk).await?;
93 let mut num_array = Vec::with_capacity(data_chunk.capacity());
94
95 let arr_right_inner = arr_right.as_list();
96 let elem_type = arr_right_inner.data_type().into_list_element_type();
97 let capacity = arr_right_inner.flatten().len();
98
99 let mut unfolded_arr_left_builder = arr_left.create_builder(capacity);
100 let mut unfolded_arr_right_builder = elem_type.create_array_builder(capacity);
101
102 let mut unfolded_left_right =
103 |left: Option<ScalarRefImpl<'_>>,
104 right: Option<ScalarRefImpl<'_>>,
105 num_array: &mut Vec<Option<usize>>| {
106 if right.is_none() {
107 num_array.push(None);
108 return;
109 }
110
111 let array = right.unwrap().into_list();
112 let flattened = array.flatten();
113 let len = flattened.len();
114 num_array.push(Some(len));
115 unfolded_arr_left_builder.append_n(len, left);
116 for item in flattened.iter() {
117 unfolded_arr_right_builder.append(item);
118 }
119 };
120
121 if data_chunk.is_compacted() {
122 for (left, right) in arr_left.iter().zip_eq_fast(arr_right.iter()) {
123 unfolded_left_right(left, right, &mut num_array);
124 }
125 } else {
126 for ((left, right), visible) in arr_left
127 .iter()
128 .zip_eq_fast(arr_right.iter())
129 .zip_eq_fast(data_chunk.visibility().iter())
130 {
131 if !visible {
132 num_array.push(None);
133 continue;
134 }
135 unfolded_left_right(left, right, &mut num_array);
136 }
137 }
138
139 assert_eq!(num_array.len(), data_chunk.capacity());
140
141 let unfolded_arr_left = unfolded_arr_left_builder.finish();
142 let unfolded_arr_right = unfolded_arr_right_builder.finish();
143
144 assert_eq!(unfolded_arr_left.len(), unfolded_arr_right.len());
147 let unfolded_compact_len = unfolded_arr_left.len();
148
149 let data_chunk = DataChunk::new(
150 vec![unfolded_arr_left.into(), unfolded_arr_right.into()],
151 unfolded_compact_len,
152 );
153
154 let func_results = self.func.eval(&data_chunk).await?;
155 let bools = func_results.as_bool();
156 let mut offset = 0;
157 Ok(Arc::new(
158 num_array
159 .into_iter()
160 .map(|num| match num {
161 Some(num) => {
162 let range = offset..offset + num;
163 offset += num;
164 self.resolve_bools(range.map(|i| bools.value_at(i)))
165 }
166 None => None,
167 })
168 .collect::<BoolArray>()
169 .into(),
170 ))
171 }
172
173 async fn eval_row(&self, row: &OwnedRow) -> Result<Datum> {
174 let datum_left = self.left_expr.eval_row(row).await?;
175 let datum_right = self.right_expr.eval_row(row).await?;
176 let Some(array_right) = datum_right else {
177 return Ok(None);
178 };
179 let array_right = array_right.into_list().into_array();
180 let len = array_right.len();
181
182 let array_left = {
184 let mut builder = self.left_expr.return_type().create_array_builder(len);
185 builder.append_n(len, datum_left);
186 builder.finish().into_ref()
187 };
188
189 let chunk = DataChunk::new(vec![array_left, Arc::new(array_right)], len);
190 let bools = self.func.eval(&chunk).await?;
191
192 Ok(self
193 .resolve_bools(bools.as_bool().iter())
194 .map(|b| b.to_scalar_value()))
195 }
196}
197
198impl Build for SomeAllExpression {
199 fn build(
200 prost: &ExprNode,
201 build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
202 ) -> Result<Self> {
203 let outer_expr_type = prost.get_function_type().unwrap();
204 let (outer_children, outer_return_type) = get_children_and_return_type(prost)?;
205 ensure!(matches!(outer_return_type, DataType::Boolean));
206
207 let mut inner_expr_type = outer_children[0].get_function_type().unwrap();
208 let (mut inner_children, mut inner_return_type) =
209 get_children_and_return_type(&outer_children[0])?;
210 let mut stack = vec![];
211 while inner_children.len() != 2 {
212 stack.push((inner_expr_type, inner_return_type));
213 inner_expr_type = inner_children[0].get_function_type().unwrap();
214 (inner_children, inner_return_type) = get_children_and_return_type(&inner_children[0])?;
215 }
216
217 let left_expr = build_child(&inner_children[0])?;
218 let right_expr = build_child(&inner_children[1])?;
219
220 let DataType::List(right_list_type) = right_expr.return_type() else {
221 bail!("Expect Array Type");
222 };
223 let right_expr_return_type = right_list_type.into_elem();
224
225 let eval_func = {
226 let left_expr_input_ref = ExprNode {
227 function_type: Type::Unspecified as i32,
228 return_type: Some(left_expr.return_type().to_protobuf()),
229 rex_node: Some(RexNode::InputRef(0)),
230 };
231 let right_expr_input_ref = ExprNode {
232 function_type: Type::Unspecified as i32,
233 return_type: Some(right_expr_return_type.to_protobuf()),
234 rex_node: Some(RexNode::InputRef(1)),
235 };
236 let mut root_expr_node = ExprNode {
237 function_type: inner_expr_type as i32,
238 return_type: Some(inner_return_type.to_protobuf()),
239 rex_node: Some(RexNode::FuncCall(FunctionCall {
240 children: vec![left_expr_input_ref, right_expr_input_ref],
241 })),
242 };
243 while let Some((expr_type, return_type)) = stack.pop() {
244 root_expr_node = ExprNode {
245 function_type: expr_type as i32,
246 return_type: Some(return_type.to_protobuf()),
247 rex_node: Some(RexNode::FuncCall(FunctionCall {
248 children: vec![root_expr_node],
249 })),
250 }
251 }
252 build_child(&root_expr_node)?
253 };
254
255 Ok(SomeAllExpression::new(
256 left_expr,
257 right_expr,
258 outer_expr_type,
259 eval_func,
260 ))
261 }
262}
263
#[cfg(test)]
mod tests {
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Builds the expression under test for the given quantifier: the list is
    /// read from input column 0 and the per-pair predicate simply returns the
    /// list element itself.
    fn quantified(kind: Type) -> SomeAllExpression {
        SomeAllExpression::new(
            build_from_pretty("0:int4"),
            build_from_pretty("$0:boolean"),
            kind,
            build_from_pretty("$1:boolean"),
        )
    }

    /// Splits `pretty` into an input column and an expected column, then checks
    /// that both the chunk-wise and the row-wise evaluation reproduce the
    /// expected column.
    async fn check(expr: &SomeAllExpression, pretty: &str) {
        let (input, expected) = DataChunk::from_pretty(pretty).split_column_at(1);

        // Chunk-wise evaluation.
        let actual = expr.eval(&input).await.unwrap();
        assert_eq!(&actual, expected.column_at(0));

        // Row-wise evaluation.
        for (in_row, want_row) in input.rows().zip_eq_debug(expected.rows()) {
            let got = expr.eval_row(&in_row.to_owned_row()).await.unwrap();
            assert_eq!(got, want_row.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_some() {
        let expr = quantified(Type::Some);
        check(
            &expr,
            "B[] B
             . .
             {} f
             {NULL} .
             {NULL,f} .
             {NULL,t} t
             {t,f} t
             {f,t} t",
        )
        .await;
    }

    #[tokio::test]
    async fn test_all() {
        let expr = quantified(Type::All);
        check(
            &expr,
            "B[] B
             . .
             {} t
             {NULL} .
             {NULL,t} .
             {NULL,f} f
             {f,f} f
             {t} t",
        )
        .await;
    }
}