risingwave_expr/expr/
expr_some_all.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::array::{Array, ArrayRef, BoolArray, DataChunk};
18use risingwave_common::row::OwnedRow;
19use risingwave_common::types::{DataType, Datum, Scalar, ScalarRefImpl};
20use risingwave_common::util::iter_util::ZipEqFast;
21use risingwave_common::{bail, ensure};
22use risingwave_pb::expr::expr_node::{RexNode, Type};
23use risingwave_pb::expr::{ExprNode, FunctionCall};
24
25use super::build::get_children_and_return_type;
26use super::{BoxedExpression, Build, Expression};
27use crate::Result;
28
29#[derive(Debug)]
30pub struct SomeAllExpression {
31    left_expr: BoxedExpression,
32    right_expr: BoxedExpression,
33    expr_type: Type,
34    func: BoxedExpression,
35}
36
37impl SomeAllExpression {
38    pub fn new(
39        left_expr: BoxedExpression,
40        right_expr: BoxedExpression,
41        expr_type: Type,
42        func: BoxedExpression,
43    ) -> Self {
44        SomeAllExpression {
45            left_expr,
46            right_expr,
47            expr_type,
48            func,
49        }
50    }
51
52    // Notice that this function may not exhaust the iterator,
53    // so never pass an iterator created `by_ref`.
54    fn resolve_bools(&self, bools: impl Iterator<Item = Option<bool>>) -> Option<bool> {
55        match self.expr_type {
56            Type::Some => {
57                let mut any_none = false;
58                for b in bools {
59                    match b {
60                        Some(true) => return Some(true),
61                        Some(false) => continue,
62                        None => any_none = true,
63                    }
64                }
65                if any_none { None } else { Some(false) }
66            }
67            Type::All => {
68                let mut all_true = true;
69                for b in bools {
70                    if b == Some(false) {
71                        return Some(false);
72                    }
73                    if b != Some(true) {
74                        all_true = false;
75                    }
76                }
77                if all_true { Some(true) } else { None }
78            }
79            _ => unreachable!(),
80        }
81    }
82}
83
84#[async_trait::async_trait]
85impl Expression for SomeAllExpression {
86    fn return_type(&self) -> DataType {
87        DataType::Boolean
88    }
89
90    async fn eval(&self, data_chunk: &DataChunk) -> Result<ArrayRef> {
91        let arr_left = self.left_expr.eval(data_chunk).await?;
92        let arr_right = self.right_expr.eval(data_chunk).await?;
93        let mut num_array = Vec::with_capacity(data_chunk.capacity());
94
95        let arr_right_inner = arr_right.as_list();
96        let elem_type = arr_right_inner.data_type().into_list_element_type();
97        let capacity = arr_right_inner.flatten().len();
98
99        let mut unfolded_arr_left_builder = arr_left.create_builder(capacity);
100        let mut unfolded_arr_right_builder = elem_type.create_array_builder(capacity);
101
102        let mut unfolded_left_right =
103            |left: Option<ScalarRefImpl<'_>>,
104             right: Option<ScalarRefImpl<'_>>,
105             num_array: &mut Vec<Option<usize>>| {
106                if right.is_none() {
107                    num_array.push(None);
108                    return;
109                }
110
111                let array = right.unwrap().into_list();
112                let flattened = array.flatten();
113                let len = flattened.len();
114                num_array.push(Some(len));
115                unfolded_arr_left_builder.append_n(len, left);
116                for item in flattened.iter() {
117                    unfolded_arr_right_builder.append(item);
118                }
119            };
120
121        if data_chunk.is_compacted() {
122            for (left, right) in arr_left.iter().zip_eq_fast(arr_right.iter()) {
123                unfolded_left_right(left, right, &mut num_array);
124            }
125        } else {
126            for ((left, right), visible) in arr_left
127                .iter()
128                .zip_eq_fast(arr_right.iter())
129                .zip_eq_fast(data_chunk.visibility().iter())
130            {
131                if !visible {
132                    num_array.push(None);
133                    continue;
134                }
135                unfolded_left_right(left, right, &mut num_array);
136            }
137        }
138
139        assert_eq!(num_array.len(), data_chunk.capacity());
140
141        let unfolded_arr_left = unfolded_arr_left_builder.finish();
142        let unfolded_arr_right = unfolded_arr_right_builder.finish();
143
144        // Unfolded array are actually compacted, and the visibility of the output array will be
145        // further restored by `num_array`.
146        assert_eq!(unfolded_arr_left.len(), unfolded_arr_right.len());
147        let unfolded_compact_len = unfolded_arr_left.len();
148
149        let data_chunk = DataChunk::new(
150            vec![unfolded_arr_left.into(), unfolded_arr_right.into()],
151            unfolded_compact_len,
152        );
153
154        let func_results = self.func.eval(&data_chunk).await?;
155        let bools = func_results.as_bool();
156        let mut offset = 0;
157        Ok(Arc::new(
158            num_array
159                .into_iter()
160                .map(|num| match num {
161                    Some(num) => {
162                        let range = offset..offset + num;
163                        offset += num;
164                        self.resolve_bools(range.map(|i| bools.value_at(i)))
165                    }
166                    None => None,
167                })
168                .collect::<BoolArray>()
169                .into(),
170        ))
171    }
172
173    async fn eval_row(&self, row: &OwnedRow) -> Result<Datum> {
174        let datum_left = self.left_expr.eval_row(row).await?;
175        let datum_right = self.right_expr.eval_row(row).await?;
176        let Some(array_right) = datum_right else {
177            return Ok(None);
178        };
179        let array_right = array_right.into_list().into_array();
180        let len = array_right.len();
181
182        // expand left to array
183        let array_left = {
184            let mut builder = self.left_expr.return_type().create_array_builder(len);
185            builder.append_n(len, datum_left);
186            builder.finish().into_ref()
187        };
188
189        let chunk = DataChunk::new(vec![array_left, Arc::new(array_right)], len);
190        let bools = self.func.eval(&chunk).await?;
191
192        Ok(self
193            .resolve_bools(bools.as_bool().iter())
194            .map(|b| b.to_scalar_value()))
195    }
196}
197
198impl Build for SomeAllExpression {
199    fn build(
200        prost: &ExprNode,
201        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
202    ) -> Result<Self> {
203        let outer_expr_type = prost.get_function_type().unwrap();
204        let (outer_children, outer_return_type) = get_children_and_return_type(prost)?;
205        ensure!(matches!(outer_return_type, DataType::Boolean));
206
207        let mut inner_expr_type = outer_children[0].get_function_type().unwrap();
208        let (mut inner_children, mut inner_return_type) =
209            get_children_and_return_type(&outer_children[0])?;
210        let mut stack = vec![];
211        while inner_children.len() != 2 {
212            stack.push((inner_expr_type, inner_return_type));
213            inner_expr_type = inner_children[0].get_function_type().unwrap();
214            (inner_children, inner_return_type) = get_children_and_return_type(&inner_children[0])?;
215        }
216
217        let left_expr = build_child(&inner_children[0])?;
218        let right_expr = build_child(&inner_children[1])?;
219
220        let DataType::List(right_list_type) = right_expr.return_type() else {
221            bail!("Expect Array Type");
222        };
223        let right_expr_return_type = right_list_type.into_elem();
224
225        let eval_func = {
226            let left_expr_input_ref = ExprNode {
227                function_type: Type::Unspecified as i32,
228                return_type: Some(left_expr.return_type().to_protobuf()),
229                rex_node: Some(RexNode::InputRef(0)),
230            };
231            let right_expr_input_ref = ExprNode {
232                function_type: Type::Unspecified as i32,
233                return_type: Some(right_expr_return_type.to_protobuf()),
234                rex_node: Some(RexNode::InputRef(1)),
235            };
236            let mut root_expr_node = ExprNode {
237                function_type: inner_expr_type as i32,
238                return_type: Some(inner_return_type.to_protobuf()),
239                rex_node: Some(RexNode::FuncCall(FunctionCall {
240                    children: vec![left_expr_input_ref, right_expr_input_ref],
241                })),
242            };
243            while let Some((expr_type, return_type)) = stack.pop() {
244                root_expr_node = ExprNode {
245                    function_type: expr_type as i32,
246                    return_type: Some(return_type.to_protobuf()),
247                    rex_node: Some(RexNode::FuncCall(FunctionCall {
248                        children: vec![root_expr_node],
249                    })),
250                }
251            }
252            build_child(&root_expr_node)?
253        };
254
255        Ok(SomeAllExpression::new(
256            left_expr,
257            right_expr,
258            outer_expr_type,
259            eval_func,
260        ))
261    }
262}
263
#[cfg(test)]
mod tests {
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Builds the expression under test with the given quantifier.
    ///
    /// The left operand is `0:int4`, the right operand is input column `$0`, and the predicate
    /// `$1:boolean` returns the unfolded list element itself. The result therefore depends only
    /// on how the list's booleans are folded.
    fn list_fold_expr(expr_type: Type) -> SomeAllExpression {
        SomeAllExpression::new(
            build_from_pretty("0:int4"),
            build_from_pretty("$0:boolean"),
            expr_type,
            build_from_pretty("$1:boolean"),
        )
    }

    /// Runs `expr` through both `eval` and `eval_row` against a pretty-printed chunk.
    ///
    /// Column 0 of the chunk is the input and column 1 the expected output.
    async fn check_both_paths(expr: &SomeAllExpression, pretty: &str) {
        let (input, expected) = DataChunk::from_pretty(pretty).split_column_at(1);

        // vectorized path
        let output = expr.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // row-at-a-time path
        for (row, want) in input.rows().zip_eq_debug(expected.rows()) {
            let got = expr.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(got, want.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_some() {
        let expr = list_fold_expr(Type::Some);
        check_both_paths(
            &expr,
            "B[]        B
             .          .
             {}         f
             {NULL}     .
             {NULL,f}   .
             {NULL,t}   t
             {t,f}      t
             {f,t}      t", // <- regression test for #14214
        )
        .await;
    }

    #[tokio::test]
    async fn test_all() {
        let expr = list_fold_expr(Type::All);
        check_both_paths(
            &expr,
            "B[]        B
             .          .
             {}         t
             {NULL}     .
             {NULL,t}   .
             {NULL,f}   f
             {f,f}      f
             {t}        t", // <- regression test for #14214
        )
        .await;
    }
}