risingwave_expr_impl/scalar/
case.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use risingwave_common::array::{ArrayRef, DataChunk};
19use risingwave_common::bail;
20use risingwave_common::row::{OwnedRow, Row};
21use risingwave_common::types::{DataType, Datum, ScalarImpl};
22use risingwave_expr::expr::{BoxedExpression, Expression};
23use risingwave_expr::{Result, build_function};
24
25#[derive(Debug)]
26struct WhenClause {
27    when: BoxedExpression,
28    then: BoxedExpression,
29}
30
31#[derive(Debug)]
32struct CaseExpression {
33    return_type: DataType,
34    when_clauses: Vec<WhenClause>,
35    else_clause: Option<BoxedExpression>,
36}
37
38impl CaseExpression {
39    fn new(
40        return_type: DataType,
41        when_clauses: Vec<WhenClause>,
42        else_clause: Option<BoxedExpression>,
43    ) -> Self {
44        Self {
45            return_type,
46            when_clauses,
47            else_clause,
48        }
49    }
50}
51
52#[async_trait::async_trait]
53impl Expression for CaseExpression {
54    fn return_type(&self) -> DataType {
55        self.return_type.clone()
56    }
57
58    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
59        let mut input = input.clone();
60        let input_len = input.capacity();
61        let mut selection = vec![None; input_len];
62        let when_len = self.when_clauses.len();
63        let mut result_array = Vec::with_capacity(when_len + 1);
64        for (when_idx, WhenClause { when, then }) in self.when_clauses.iter().enumerate() {
65            let input_vis = input.visibility().clone();
66            // note that evaluated result from when clause may contain bits that are not visible,
67            // so we need to mask it with input visibility.
68            let calc_then_vis = when.eval(&input).await?.as_bool().to_bitmap() & &input_vis;
69            input.set_visibility(calc_then_vis.clone());
70            let then_res = then.eval(&input).await?;
71            calc_then_vis
72                .iter_ones()
73                .for_each(|pos| selection[pos] = Some(when_idx));
74            input.set_visibility(&input_vis & (!calc_then_vis));
75            result_array.push(then_res);
76        }
77        if let Some(ref else_expr) = self.else_clause {
78            let else_res = else_expr.eval(&input).await?;
79            input
80                .visibility()
81                .iter_ones()
82                .for_each(|pos| selection[pos] = Some(when_len));
83            result_array.push(else_res);
84        }
85        let mut builder = self.return_type().create_array_builder(input.capacity());
86        for (i, sel) in selection.into_iter().enumerate() {
87            if let Some(when_idx) = sel {
88                builder.append(result_array[when_idx].value_at(i));
89            } else {
90                builder.append_null();
91            }
92        }
93        Ok(Arc::new(builder.finish()))
94    }
95
96    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
97        for WhenClause { when, then } in &self.when_clauses {
98            if when.eval_row(input).await?.is_some_and(|w| w.into_bool()) {
99                return then.eval_row(input).await;
100            }
101        }
102        if let Some(ref else_expr) = self.else_clause {
103            else_expr.eval_row(input).await
104        } else {
105            Ok(None)
106        }
107    }
108}
109
110/// With large scale of simple form match arms in case-when expression,
111/// we could optimize the `CaseExpression` to `ConstantLookupExpression`,
112/// which could significantly facilitate the evaluation of case-when.
113#[derive(Debug)]
114struct ConstantLookupExpression {
115    return_type: DataType,
116    arms: HashMap<ScalarImpl, BoxedExpression>,
117    fallback: Option<BoxedExpression>,
118    /// `operand` must exist at present
119    operand: BoxedExpression,
120}
121
122impl ConstantLookupExpression {
123    fn new(
124        return_type: DataType,
125        arms: HashMap<ScalarImpl, BoxedExpression>,
126        fallback: Option<BoxedExpression>,
127        operand: BoxedExpression,
128    ) -> Self {
129        Self {
130            return_type,
131            arms,
132            fallback,
133            operand,
134        }
135    }
136
137    /// Evaluate the fallback arm with the given input
138    async fn eval_fallback(&self, input: &OwnedRow) -> Result<Datum> {
139        let Some(ref fallback) = self.fallback else {
140            return Ok(None);
141        };
142        let Ok(res) = fallback.eval_row(input).await else {
143            bail!("failed to evaluate the input for fallback arm");
144        };
145        Ok(res)
146    }
147
148    /// The actual lookup & evaluation logic
149    /// used in both `eval_row` & `eval`
150    async fn lookup(&self, datum: Datum, input: &OwnedRow) -> Result<Datum> {
151        if datum.is_none() {
152            return self.eval_fallback(input).await;
153        }
154
155        if let Some(expr) = self.arms.get(datum.as_ref().unwrap()) {
156            let Ok(res) = expr.eval_row(input).await else {
157                bail!("failed to evaluate the input for normal arm");
158            };
159            Ok(res)
160        } else {
161            // Fallback arm goes here
162            self.eval_fallback(input).await
163        }
164    }
165}
166
167#[async_trait::async_trait]
168impl Expression for ConstantLookupExpression {
169    fn return_type(&self) -> DataType {
170        self.return_type.clone()
171    }
172
173    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
174        let input_len = input.capacity();
175        let mut builder = self.return_type().create_array_builder(input_len);
176
177        // Evaluate the input DataChunk at first
178        let eval_result = self.operand.eval(input).await?;
179
180        for i in 0..input_len {
181            let datum = eval_result.datum_at(i);
182            let (row, vis) = input.row_at(i);
183
184            // Check for visibility
185            if !vis {
186                builder.append_null();
187                continue;
188            }
189
190            // Note that the `owned_row` here is extracted from input
191            // rather than from `eval_result`
192            let owned_row = row.into_owned_row();
193
194            // Lookup and evaluate the current input datum
195            if let Ok(datum) = self.lookup(datum, &owned_row).await {
196                builder.append(datum.as_ref());
197            } else {
198                bail!("failed to lookup and evaluate the expression in `eval`");
199            }
200        }
201
202        Ok(Arc::new(builder.finish()))
203    }
204
205    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
206        let datum = self.operand.eval_row(input).await?;
207        self.lookup(datum, input).await
208    }
209}
210
211#[build_function("constant_lookup(...) -> any", type_infer = "unreachable")]
212fn build_constant_lookup_expr(
213    return_type: DataType,
214    children: Vec<BoxedExpression>,
215) -> Result<BoxedExpression> {
216    if children.is_empty() {
217        bail!("children expression must not be empty for constant lookup expression");
218    }
219
220    let mut children = children;
221
222    let operand = children.remove(0);
223
224    let mut arms = HashMap::new();
225
226    // Build the `arms` with iterating over `when` & `then` clauses
227    let mut iter = children.into_iter().array_chunks();
228    for [when, then] in iter.by_ref() {
229        let Ok(Some(s)) = when.eval_const() else {
230            bail!("expect when expression to be const");
231        };
232        arms.insert(s, then);
233    }
234
235    let fallback = if let Some(else_clause) = iter.into_remainder().unwrap().next() {
236        if else_clause.return_type() != return_type {
237            bail!("Type mismatched between else and case.");
238        }
239        Some(else_clause)
240    } else {
241        None
242    };
243
244    Ok(Box::new(ConstantLookupExpression::new(
245        return_type,
246        arms,
247        fallback,
248        operand,
249    )))
250}
251
252#[build_function("case(...) -> any", type_infer = "unreachable")]
253fn build_case_expr(
254    return_type: DataType,
255    children: Vec<BoxedExpression>,
256) -> Result<BoxedExpression> {
257    // children: (when, then)+, (else_clause)?
258    let len = children.len();
259    let mut when_clauses = Vec::with_capacity(len / 2);
260    let mut iter = children.into_iter().array_chunks();
261    for [when, then] in iter.by_ref() {
262        if when.return_type() != DataType::Boolean {
263            bail!("Type mismatched between when clause and condition");
264        }
265        if then.return_type() != return_type {
266            bail!("Type mismatched between then clause and case");
267        }
268        when_clauses.push(WhenClause { when, then });
269    }
270    let else_clause = if let Some(else_clause) = iter.into_remainder().unwrap().next() {
271        if else_clause.return_type() != return_type {
272            bail!("Type mismatched between else and case.");
273        }
274        Some(else_clause)
275    } else {
276        None
277    };
278
279    Ok(Box::new(CaseExpression::new(
280        return_type,
281        when_clauses,
282        else_clause,
283    )))
284}
285
#[cfg(test)]
mod tests {
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Checks that `case` yields column 0 of `expected` for `input`, first through the
    /// chunk-level `eval` and then row by row through `eval_row`.
    async fn assert_case_output(
        case: &BoxedExpression,
        input: &DataChunk,
        expected: &DataChunk,
    ) {
        // test eval
        let output = case.eval(input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // test eval_row
        for (row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let result = case.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(result, expected_row.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_eval_searched_case() {
        // when x then 1 else 2
        let case = build_from_pretty("(case:int4 $0:boolean 1:int4 2:int4)");
        // The last column holds the expected output; the rest is the input.
        let chunk = DataChunk::from_pretty(
            "B i
             t 1
             f 2
             t 1
             t 1
             f 2",
        );
        let (input, expected) = chunk.split_column_at(1);

        assert_case_output(&case, &input, &expected).await;
    }

    #[tokio::test]
    async fn test_eval_without_else() {
        // when x then 1 when y then 2
        let case = build_from_pretty("(case:int4 $0:boolean 1:int4 $1:boolean 2:int4)");
        // The last column holds the expected output; the rest is the input.
        let chunk = DataChunk::from_pretty(
            "B B i
             f f .
             f t 2
             t f 1
             t t 1",
        );
        let (input, expected) = chunk.split_column_at(2);

        assert_case_output(&case, &input, &expected).await;
    }
}
343}