risingwave_expr/expr/expr_some_all.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::array::{Array, ArrayRef, BoolArray, DataChunk};
18use risingwave_common::row::OwnedRow;
19use risingwave_common::types::{DataType, Datum, Scalar, ScalarRefImpl};
20use risingwave_common::util::iter_util::ZipEqFast;
21use risingwave_common::{bail, ensure};
22use risingwave_pb::expr::expr_node::{RexNode, Type};
23use risingwave_pb::expr::{ExprNode, FunctionCall};
24
25use super::build::get_children_and_return_type;
26use super::{BoxedExpression, Build, Expression};
27use crate::Result;
28
/// Evaluates a `SOME`/`ALL` comparison between a left operand and a list-valued right operand.
///
/// The comparison itself is not hard-coded: `func` is evaluated over (left value, list element)
/// pairs, and its boolean results are combined per row according to `expr_type`.
#[derive(Debug)]
pub struct SomeAllExpression {
    /// Produces the left-hand operand for each row.
    left_expr: BoxedExpression,
    /// Produces the right-hand operand for each row; its result is read as a list.
    right_expr: BoxedExpression,
    /// Expected to be `Type::Some` or `Type::All`; any other value hits `unreachable!()` when
    /// per-element results are combined.
    expr_type: Type,
    /// Evaluated on a two-column chunk (column 0: left value, column 1: one list element);
    /// its result is read as booleans.
    func: BoxedExpression,
}
36
37impl SomeAllExpression {
38    pub fn new(
39        left_expr: BoxedExpression,
40        right_expr: BoxedExpression,
41        expr_type: Type,
42        func: BoxedExpression,
43    ) -> Self {
44        SomeAllExpression {
45            left_expr,
46            right_expr,
47            expr_type,
48            func,
49        }
50    }
51
52    // Notice that this function may not exhaust the iterator,
53    // so never pass an iterator created `by_ref`.
54    fn resolve_bools(&self, bools: impl Iterator<Item = Option<bool>>) -> Option<bool> {
55        match self.expr_type {
56            Type::Some => {
57                let mut any_none = false;
58                for b in bools {
59                    match b {
60                        Some(true) => return Some(true),
61                        Some(false) => continue,
62                        None => any_none = true,
63                    }
64                }
65                if any_none { None } else { Some(false) }
66            }
67            Type::All => {
68                let mut all_true = true;
69                for b in bools {
70                    if b == Some(false) {
71                        return Some(false);
72                    }
73                    if b != Some(true) {
74                        all_true = false;
75                    }
76                }
77                if all_true { Some(true) } else { None }
78            }
79            _ => unreachable!(),
80        }
81    }
82}
83
84#[async_trait::async_trait]
85impl Expression for SomeAllExpression {
86    fn return_type(&self) -> DataType {
87        DataType::Boolean
88    }
89
90    async fn eval(&self, data_chunk: &DataChunk) -> Result<ArrayRef> {
91        let arr_left = self.left_expr.eval(data_chunk).await?;
92        let arr_right = self.right_expr.eval(data_chunk).await?;
93        let mut num_array = Vec::with_capacity(data_chunk.capacity());
94
95        let arr_right_inner = arr_right.as_list();
96        let DataType::List(datatype) = arr_right_inner.data_type() else {
97            unreachable!()
98        };
99        let capacity = arr_right_inner.flatten().len();
100
101        let mut unfolded_arr_left_builder = arr_left.create_builder(capacity);
102        let mut unfolded_arr_right_builder = datatype.create_array_builder(capacity);
103
104        let mut unfolded_left_right =
105            |left: Option<ScalarRefImpl<'_>>,
106             right: Option<ScalarRefImpl<'_>>,
107             num_array: &mut Vec<Option<usize>>| {
108                if right.is_none() {
109                    num_array.push(None);
110                    return;
111                }
112
113                let array = right.unwrap().into_list();
114                let flattened = array.flatten();
115                let len = flattened.len();
116                num_array.push(Some(len));
117                unfolded_arr_left_builder.append_n(len, left);
118                for item in flattened.iter() {
119                    unfolded_arr_right_builder.append(item);
120                }
121            };
122
123        if data_chunk.is_compacted() {
124            for (left, right) in arr_left.iter().zip_eq_fast(arr_right.iter()) {
125                unfolded_left_right(left, right, &mut num_array);
126            }
127        } else {
128            for ((left, right), visible) in arr_left
129                .iter()
130                .zip_eq_fast(arr_right.iter())
131                .zip_eq_fast(data_chunk.visibility().iter())
132            {
133                if !visible {
134                    num_array.push(None);
135                    continue;
136                }
137                unfolded_left_right(left, right, &mut num_array);
138            }
139        }
140
141        assert_eq!(num_array.len(), data_chunk.capacity());
142
143        let unfolded_arr_left = unfolded_arr_left_builder.finish();
144        let unfolded_arr_right = unfolded_arr_right_builder.finish();
145
146        // Unfolded array are actually compacted, and the visibility of the output array will be
147        // further restored by `num_array`.
148        assert_eq!(unfolded_arr_left.len(), unfolded_arr_right.len());
149        let unfolded_compact_len = unfolded_arr_left.len();
150
151        let data_chunk = DataChunk::new(
152            vec![unfolded_arr_left.into(), unfolded_arr_right.into()],
153            unfolded_compact_len,
154        );
155
156        let func_results = self.func.eval(&data_chunk).await?;
157        let bools = func_results.as_bool();
158        let mut offset = 0;
159        Ok(Arc::new(
160            num_array
161                .into_iter()
162                .map(|num| match num {
163                    Some(num) => {
164                        let range = offset..offset + num;
165                        offset += num;
166                        self.resolve_bools(range.map(|i| bools.value_at(i)))
167                    }
168                    None => None,
169                })
170                .collect::<BoolArray>()
171                .into(),
172        ))
173    }
174
175    async fn eval_row(&self, row: &OwnedRow) -> Result<Datum> {
176        let datum_left = self.left_expr.eval_row(row).await?;
177        let datum_right = self.right_expr.eval_row(row).await?;
178        let Some(array_right) = datum_right else {
179            return Ok(None);
180        };
181        let array_right = array_right.into_list().into_array();
182        let len = array_right.len();
183
184        // expand left to array
185        let array_left = {
186            let mut builder = self.left_expr.return_type().create_array_builder(len);
187            builder.append_n(len, datum_left);
188            builder.finish().into_ref()
189        };
190
191        let chunk = DataChunk::new(vec![array_left, Arc::new(array_right)], len);
192        let bools = self.func.eval(&chunk).await?;
193
194        Ok(self
195            .resolve_bools(bools.as_bool().iter())
196            .map(|b| b.to_scalar_value()))
197    }
198}
199
200impl Build for SomeAllExpression {
201    fn build(
202        prost: &ExprNode,
203        build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
204    ) -> Result<Self> {
205        let outer_expr_type = prost.get_function_type().unwrap();
206        let (outer_children, outer_return_type) = get_children_and_return_type(prost)?;
207        ensure!(matches!(outer_return_type, DataType::Boolean));
208
209        let mut inner_expr_type = outer_children[0].get_function_type().unwrap();
210        let (mut inner_children, mut inner_return_type) =
211            get_children_and_return_type(&outer_children[0])?;
212        let mut stack = vec![];
213        while inner_children.len() != 2 {
214            stack.push((inner_expr_type, inner_return_type));
215            inner_expr_type = inner_children[0].get_function_type().unwrap();
216            (inner_children, inner_return_type) = get_children_and_return_type(&inner_children[0])?;
217        }
218
219        let left_expr = build_child(&inner_children[0])?;
220        let right_expr = build_child(&inner_children[1])?;
221
222        let DataType::List(right_expr_return_type) = right_expr.return_type() else {
223            bail!("Expect Array Type");
224        };
225
226        let eval_func = {
227            let left_expr_input_ref = ExprNode {
228                function_type: Type::Unspecified as i32,
229                return_type: Some(left_expr.return_type().to_protobuf()),
230                rex_node: Some(RexNode::InputRef(0)),
231            };
232            let right_expr_input_ref = ExprNode {
233                function_type: Type::Unspecified as i32,
234                return_type: Some(right_expr_return_type.to_protobuf()),
235                rex_node: Some(RexNode::InputRef(1)),
236            };
237            let mut root_expr_node = ExprNode {
238                function_type: inner_expr_type as i32,
239                return_type: Some(inner_return_type.to_protobuf()),
240                rex_node: Some(RexNode::FuncCall(FunctionCall {
241                    children: vec![left_expr_input_ref, right_expr_input_ref],
242                })),
243            };
244            while let Some((expr_type, return_type)) = stack.pop() {
245                root_expr_node = ExprNode {
246                    function_type: expr_type as i32,
247                    return_type: Some(return_type.to_protobuf()),
248                    rex_node: Some(RexNode::FuncCall(FunctionCall {
249                        children: vec![root_expr_node],
250                    })),
251                }
252            }
253            build_child(&root_expr_node)?
254        };
255
256        Ok(SomeAllExpression::new(
257            left_expr,
258            right_expr,
259            outer_expr_type,
260            eval_func,
261        ))
262    }
263}
264
#[cfg(test)]
mod tests {
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Builds a SOME/ALL expression whose "comparison" simply forwards the unfolded list
    /// element (`$1`), so the result depends only on how the element booleans are combined.
    fn forwarding_expr(expr_type: Type) -> SomeAllExpression {
        SomeAllExpression::new(
            build_from_pretty("0:int4"),
            build_from_pretty("$0:boolean"),
            expr_type,
            build_from_pretty("$1:boolean"),
        )
    }

    /// Parses `pretty` into an input column and an expected column. Then checks that both the
    /// vectorized `eval` and the row-wise `eval_row` produce the expected results.
    async fn check(expr_type: Type, pretty: &str) {
        let expr = forwarding_expr(expr_type);
        let (input, expected) = DataChunk::from_pretty(pretty).split_column_at(1);

        // Vectorized path.
        let output = expr.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // Row-at-a-time path must agree row by row.
        for (row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let result = expr.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(result, expected_row.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_some() {
        check(
            Type::Some,
            "B[]        B
             .          .
             {}         f
             {NULL}     .
             {NULL,f}   .
             {NULL,t}   t
             {t,f}      t
             {f,t}      t", // <- regression test for #14214
        )
        .await;
    }

    #[tokio::test]
    async fn test_all() {
        check(
            Type::All,
            "B[]        B
             .          .
             {}         t
             {NULL}     .
             {NULL,t}   .
             {NULL,f}   f
             {f,f}      f
             {t}        t", // <- regression test for #14214
        )
        .await;
    }
}