// risingwave_expr_impl/scalar/coalesce.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::ops::BitAnd;
16use std::sync::Arc;
17
18use risingwave_common::array::{ArrayRef, DataChunk};
19use risingwave_common::row::OwnedRow;
20use risingwave_common::types::{DataType, Datum};
21use risingwave_expr::expr::{BoxedExpression, Expression};
22use risingwave_expr::{Result, build_function};
23
24#[derive(Debug)]
25pub struct CoalesceExpression {
26    return_type: DataType,
27    children: Vec<BoxedExpression>,
28}
29
30#[async_trait::async_trait]
31impl Expression for CoalesceExpression {
32    fn return_type(&self) -> DataType {
33        self.return_type.clone()
34    }
35
36    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
37        let init_vis = input.visibility();
38        let mut input = input.clone();
39        let len = input.capacity();
40        let mut selection: Vec<Option<usize>> = vec![None; len];
41        let mut children_array = Vec::with_capacity(self.children.len());
42        for (child_idx, child) in self.children.iter().enumerate() {
43            let res = child.eval(&input).await?;
44            let res_bitmap = res.null_bitmap();
45            let orig_vis = input.visibility();
46            for pos in orig_vis.bitand(res_bitmap).iter_ones() {
47                selection[pos] = Some(child_idx);
48            }
49            let new_vis = orig_vis & !res_bitmap;
50            input.set_visibility(new_vis);
51            children_array.push(res);
52        }
53        let mut builder = self.return_type.create_array_builder(len);
54        for (i, sel) in selection.iter().enumerate() {
55            if init_vis.is_set(i)
56                && let Some(child_idx) = sel
57            {
58                builder.append(children_array[*child_idx].value_at(i));
59            } else {
60                builder.append_null()
61            }
62        }
63        Ok(Arc::new(builder.finish()))
64    }
65
66    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
67        for child in &self.children {
68            let datum = child.eval_row(input).await?;
69            if datum.is_some() {
70                return Ok(datum);
71            }
72        }
73        Ok(None)
74    }
75}
76
77#[build_function("coalesce(...) -> any", type_infer = "unreachable")]
78fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
79    Ok(Box::new(CoalesceExpression {
80        return_type,
81        children,
82    }))
83}
84
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Both the vectorized and the row-based paths must pick the first non-null argument.
    #[tokio::test]
    async fn test_coalesce_expr() {
        let coalesce = build_from_pretty("(coalesce:int4 $0:int4 $1:int4 $2:int4)");
        // Columns 0..3 are the arguments; column 3 holds the expected result.
        let (args, want) = DataChunk::from_pretty(
            "i i i i
             1 . . 1
             . 2 . 2
             . . 3 3
             . . . .",
        )
        .split_column_at(3);
        let want_column = want.column_at(0);

        // Vectorized evaluation over the whole chunk.
        let got = coalesce.eval(&args).await.unwrap();
        assert_eq!(&got, want_column);

        // Row-at-a-time evaluation must agree with the expected column, row by row.
        for (arg_row, want_row) in args.rows().zip_eq_debug(want.rows()) {
            let got_datum = coalesce.eval_row(&arg_row.to_owned_row()).await.unwrap();
            assert_eq!(got_datum, want_row.datum_at(0).to_owned_datum());
        }
    }
}