risingwave_expr_impl/scalar/
case.rs1use std::collections::HashMap;
16use std::sync::Arc;
17
18use risingwave_common::array::{ArrayRef, DataChunk};
19use risingwave_common::bail;
20use risingwave_common::row::{OwnedRow, Row};
21use risingwave_common::types::{DataType, Datum, ScalarImpl};
22use risingwave_expr::expr::{BoxedExpression, Expression};
23use risingwave_expr::{Result, build_function};
24
/// One `WHEN condition THEN result` arm of a searched `CASE` expression.
#[derive(Debug)]
struct WhenClause {
    /// The condition; `build_case_expr` requires it to return `boolean`.
    when: BoxedExpression,
    /// The value produced for rows whose condition evaluates to true.
    then: BoxedExpression,
}
30
/// Searched `CASE WHEN c1 THEN r1 [WHEN ...] [ELSE e] END`.
///
/// Arms are tried in order and the first one whose condition is true supplies the result;
/// rows matched by no arm take the `ELSE` value, or NULL when there is no `ELSE`.
#[derive(Debug)]
struct CaseExpression {
    /// Type of the whole expression; `build_case_expr` requires every `THEN`/`ELSE` to match it.
    return_type: DataType,
    /// The `WHEN ... THEN ...` arms, in evaluation order.
    when_clauses: Vec<WhenClause>,
    /// Optional `ELSE` expression.
    else_clause: Option<BoxedExpression>,
}
37
38impl CaseExpression {
39 fn new(
40 return_type: DataType,
41 when_clauses: Vec<WhenClause>,
42 else_clause: Option<BoxedExpression>,
43 ) -> Self {
44 Self {
45 return_type,
46 when_clauses,
47 else_clause,
48 }
49 }
50}
51
52#[async_trait::async_trait]
53impl Expression for CaseExpression {
54 fn return_type(&self) -> DataType {
55 self.return_type.clone()
56 }
57
58 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
59 let mut input = input.clone();
60 let input_len = input.capacity();
61 let mut selection = vec![None; input_len];
62 let when_len = self.when_clauses.len();
63 let mut result_array = Vec::with_capacity(when_len + 1);
64 for (when_idx, WhenClause { when, then }) in self.when_clauses.iter().enumerate() {
65 let input_vis = input.visibility().clone();
66 let calc_then_vis = when.eval(&input).await?.as_bool().to_bitmap() & &input_vis;
69 input.set_visibility(calc_then_vis.clone());
70 let then_res = then.eval(&input).await?;
71 calc_then_vis
72 .iter_ones()
73 .for_each(|pos| selection[pos] = Some(when_idx));
74 input.set_visibility(&input_vis & (!calc_then_vis));
75 result_array.push(then_res);
76 }
77 if let Some(ref else_expr) = self.else_clause {
78 let else_res = else_expr.eval(&input).await?;
79 input
80 .visibility()
81 .iter_ones()
82 .for_each(|pos| selection[pos] = Some(when_len));
83 result_array.push(else_res);
84 }
85 let mut builder = self.return_type().create_array_builder(input.capacity());
86 for (i, sel) in selection.into_iter().enumerate() {
87 if let Some(when_idx) = sel {
88 builder.append(result_array[when_idx].value_at(i));
89 } else {
90 builder.append_null();
91 }
92 }
93 Ok(Arc::new(builder.finish()))
94 }
95
96 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
97 for WhenClause { when, then } in &self.when_clauses {
98 if when.eval_row(input).await?.is_some_and(|w| w.into_bool()) {
99 return then.eval_row(input).await;
100 }
101 }
102 if let Some(ref else_expr) = self.else_clause {
103 else_expr.eval_row(input).await
104 } else {
105 Ok(None)
106 }
107 }
108}
109
110#[derive(Debug)]
114struct ConstantLookupExpression {
115 return_type: DataType,
116 arms: HashMap<ScalarImpl, BoxedExpression>,
117 fallback: Option<BoxedExpression>,
118 operand: BoxedExpression,
120}
121
122impl ConstantLookupExpression {
123 fn new(
124 return_type: DataType,
125 arms: HashMap<ScalarImpl, BoxedExpression>,
126 fallback: Option<BoxedExpression>,
127 operand: BoxedExpression,
128 ) -> Self {
129 Self {
130 return_type,
131 arms,
132 fallback,
133 operand,
134 }
135 }
136
137 async fn eval_fallback(&self, input: &OwnedRow) -> Result<Datum> {
139 let Some(ref fallback) = self.fallback else {
140 return Ok(None);
141 };
142 let Ok(res) = fallback.eval_row(input).await else {
143 bail!("failed to evaluate the input for fallback arm");
144 };
145 Ok(res)
146 }
147
148 async fn lookup(&self, datum: Datum, input: &OwnedRow) -> Result<Datum> {
151 if datum.is_none() {
152 return self.eval_fallback(input).await;
153 }
154
155 if let Some(expr) = self.arms.get(datum.as_ref().unwrap()) {
156 let Ok(res) = expr.eval_row(input).await else {
157 bail!("failed to evaluate the input for normal arm");
158 };
159 Ok(res)
160 } else {
161 self.eval_fallback(input).await
163 }
164 }
165}
166
167#[async_trait::async_trait]
168impl Expression for ConstantLookupExpression {
169 fn return_type(&self) -> DataType {
170 self.return_type.clone()
171 }
172
173 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
174 let input_len = input.capacity();
175 let mut builder = self.return_type().create_array_builder(input_len);
176
177 let eval_result = self.operand.eval(input).await?;
179
180 for i in 0..input_len {
181 let datum = eval_result.datum_at(i);
182 let (row, vis) = input.row_at(i);
183
184 if !vis {
186 builder.append_null();
187 continue;
188 }
189
190 let owned_row = row.into_owned_row();
193
194 if let Ok(datum) = self.lookup(datum, &owned_row).await {
196 builder.append(datum.as_ref());
197 } else {
198 bail!("failed to lookup and evaluate the expression in `eval`");
199 }
200 }
201
202 Ok(Arc::new(builder.finish()))
203 }
204
205 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
206 let datum = self.operand.eval_row(input).await?;
207 self.lookup(datum, input).await
208 }
209}
210
211#[build_function("constant_lookup(...) -> any", type_infer = "unreachable")]
212fn build_constant_lookup_expr(
213 return_type: DataType,
214 children: Vec<BoxedExpression>,
215) -> Result<BoxedExpression> {
216 if children.is_empty() {
217 bail!("children expression must not be empty for constant lookup expression");
218 }
219
220 let mut children = children;
221
222 let operand = children.remove(0);
223
224 let mut arms = HashMap::new();
225
226 let mut iter = children.into_iter().array_chunks();
228 for [when, then] in iter.by_ref() {
229 let Ok(Some(s)) = when.eval_const() else {
230 bail!("expect when expression to be const");
231 };
232 arms.insert(s, then);
233 }
234
235 let fallback = if let Some(else_clause) = iter.into_remainder().unwrap().next() {
236 if else_clause.return_type() != return_type {
237 bail!("Type mismatched between else and case.");
238 }
239 Some(else_clause)
240 } else {
241 None
242 };
243
244 Ok(Box::new(ConstantLookupExpression::new(
245 return_type,
246 arms,
247 fallback,
248 operand,
249 )))
250}
251
252#[build_function("case(...) -> any", type_infer = "unreachable")]
253fn build_case_expr(
254 return_type: DataType,
255 children: Vec<BoxedExpression>,
256) -> Result<BoxedExpression> {
257 let len = children.len();
259 let mut when_clauses = Vec::with_capacity(len / 2);
260 let mut iter = children.into_iter().array_chunks();
261 for [when, then] in iter.by_ref() {
262 if when.return_type() != DataType::Boolean {
263 bail!("Type mismatched between when clause and condition");
264 }
265 if then.return_type() != return_type {
266 bail!("Type mismatched between then clause and case");
267 }
268 when_clauses.push(WhenClause { when, then });
269 }
270 let else_clause = if let Some(else_clause) = iter.into_remainder().unwrap().next() {
271 if else_clause.return_type() != return_type {
272 bail!("Type mismatched between else and case.");
273 }
274 Some(else_clause)
275 } else {
276 None
277 };
278
279 Ok(Box::new(CaseExpression::new(
280 return_type,
281 when_clauses,
282 else_clause,
283 )))
284}
285
#[cfg(test)]
mod tests {
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::*;

    /// Builds `expr` and checks both vectorized and row-wise evaluation.
    ///
    /// The first `input_columns` columns of `chunk` are the inputs; the next column holds the
    /// expected result.
    async fn assert_case(expr: &str, chunk: &str, input_columns: usize) {
        let case = build_from_pretty(expr);
        let (input, expected) = DataChunk::from_pretty(chunk).split_column_at(input_columns);

        let output = case.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        for (row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let result = case.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(result, expected_row.datum_at(0).to_owned_datum());
        }
    }

    #[tokio::test]
    async fn test_eval_searched_case() {
        assert_case(
            "(case:int4 $0:boolean 1:int4 2:int4)",
            "B i
             t 1
             f 2
             t 1
             t 1
             f 2",
            1,
        )
        .await;
    }

    #[tokio::test]
    async fn test_eval_without_else() {
        assert_case(
            "(case:int4 $0:boolean 1:int4 $1:boolean 2:int4)",
            "B B i
             f f .
             f t 2
             t f 1
             t t 1",
            2,
        )
        .await;
    }
}