1use itertools::Itertools;
16use risingwave_common::catalog::Schema;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::iter_util::ZipEqFast;
19
20use super::type_inference::cast;
21use super::{CastContext, CastError, Expr, ExprImpl, Literal, infer_some_all, infer_type};
22use crate::error::Result as RwResult;
23use crate::expr::{ExprDisplay, ExprType, bail_cast_error, is_impure_func_call};
24
/// A call to a function or operator (`func_type`) over a list of argument expressions,
/// together with the type of the value it produces.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct FunctionCall {
    /// The function or operator being invoked.
    pub(super) func_type: ExprType,
    /// Result type of the call: inferred by `FunctionCall::new`, or supplied by the caller of
    /// `FunctionCall::new_unchecked` / `FunctionCall::from_expr_proto`.
    pub(super) return_type: DataType,
    /// Argument expressions, in argument order.
    pub(super) inputs: Vec<ExprImpl>,
}
31
32fn debug_binary_op(
33 f: &mut std::fmt::Formatter<'_>,
34 op: &str,
35 inputs: &[ExprImpl],
36) -> std::fmt::Result {
37 use std::fmt::Debug;
38
39 assert_eq!(inputs.len(), 2);
40
41 write!(f, "(")?;
42 inputs[0].fmt(f)?;
43 write!(f, " {} ", op)?;
44 inputs[1].fmt(f)?;
45 write!(f, ")")?;
46
47 Ok(())
48}
49
50impl std::fmt::Debug for FunctionCall {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 if f.alternate() {
53 f.debug_struct("FunctionCall")
54 .field("func_type", &self.func_type)
55 .field("return_type", &self.return_type)
56 .field("inputs", &self.inputs)
57 .finish()
58 } else {
59 match &self.func_type {
60 ExprType::Cast => {
61 assert_eq!(self.inputs.len(), 1);
62 self.inputs[0].fmt(f)?;
63 write!(f, "::{:?}", self.return_type)
64 }
65 ExprType::Add => debug_binary_op(f, "+", &self.inputs),
66 ExprType::Subtract => debug_binary_op(f, "-", &self.inputs),
67 ExprType::Multiply => debug_binary_op(f, "*", &self.inputs),
68 ExprType::Divide => debug_binary_op(f, "/", &self.inputs),
69 ExprType::Modulus => debug_binary_op(f, "%", &self.inputs),
70 ExprType::Equal => debug_binary_op(f, "=", &self.inputs),
71 ExprType::NotEqual => debug_binary_op(f, "<>", &self.inputs),
72 ExprType::LessThan => debug_binary_op(f, "<", &self.inputs),
73 ExprType::LessThanOrEqual => debug_binary_op(f, "<=", &self.inputs),
74 ExprType::GreaterThan => debug_binary_op(f, ">", &self.inputs),
75 ExprType::GreaterThanOrEqual => debug_binary_op(f, ">=", &self.inputs),
76 ExprType::And => debug_binary_op(f, "AND", &self.inputs),
77 ExprType::Or => debug_binary_op(f, "OR", &self.inputs),
78 ExprType::BitwiseShiftLeft => debug_binary_op(f, "<<", &self.inputs),
79 ExprType::BitwiseShiftRight => debug_binary_op(f, ">>", &self.inputs),
80 ExprType::BitwiseAnd => debug_binary_op(f, "&", &self.inputs),
81 ExprType::BitwiseOr => debug_binary_op(f, "|", &self.inputs),
82 ExprType::BitwiseXor => debug_binary_op(f, "#", &self.inputs),
83 ExprType::ArrayContains => debug_binary_op(f, "@>", &self.inputs),
84 ExprType::ArrayContained => debug_binary_op(f, "<@", &self.inputs),
85 ExprType::ArrayOverlaps => debug_binary_op(f, "&&", &self.inputs),
86 _ => {
87 let func_name = format!("{:?}", self.func_type);
88 let mut builder = f.debug_tuple(&func_name);
89 self.inputs.iter().for_each(|child| {
90 builder.field(child);
91 });
92 builder.finish()
93 }
94 }
95 }
96 }
97}
98
99impl FunctionCall {
100 pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> RwResult<Self> {
106 let return_type = infer_type(func_type.into(), &mut inputs)?;
107 Ok(Self::new_unchecked(func_type, inputs, return_type))
108 }
109
110 pub fn cast_mut(
113 child: &mut ExprImpl,
114 target: &DataType,
115 allows: CastContext,
116 ) -> Result<(), CastError> {
117 if let ExprImpl::Parameter(expr) = child
118 && !expr.has_infer()
119 {
120 expr.cast_infer_type(target);
122 return Ok(());
123 }
124 if let ExprImpl::FunctionCall(func) = child
125 && func.func_type == ExprType::Row
126 {
127 return Self::cast_row_expr(func, target, allows);
131 }
132 if child.is_untyped() {
133 let literal = child.as_literal().unwrap();
135 let datum = literal
136 .get_data()
137 .as_ref()
138 .map(|scalar| ScalarImpl::from_text(scalar.as_utf8(), target))
139 .transpose();
140 if let Ok(datum) = datum {
141 *child = Literal::new(datum, target.clone()).into();
142 return Ok(());
143 }
144 }
147
148 let source = child.return_type();
149 if &source == target {
150 return Ok(());
151 }
152
153 if child.is_untyped() {
154 } else {
157 cast(&source, target, allows)?;
158 }
159
160 let owned = std::mem::replace(child, ExprImpl::literal_bool(false));
162 *child = Self::new_unchecked(ExprType::Cast, vec![owned], target.clone()).into();
163 Ok(())
164 }
165
166 fn cast_row_expr(
170 func: &mut FunctionCall,
171 target_type: &DataType,
172 allows: CastContext,
173 ) -> Result<(), CastError> {
174 let DataType::Struct(t) = &target_type else {
176 bail_cast_error!(
177 "cannot cast type \"{}\" to \"{}\"",
178 func.return_type(), target_type,
180 );
181 };
182
183 let expected_len = t.len();
184 let actual_len = func.inputs.len();
185
186 match expected_len.cmp(&actual_len) {
187 std::cmp::Ordering::Equal => {
188 func.inputs
191 .iter_mut()
192 .zip_eq_fast(t.types())
193 .try_for_each(|(e, t)| Self::cast_mut(e, t, allows))?;
194 func.return_type = target_type.clone();
195 Ok(())
196 }
197 std::cmp::Ordering::Less => bail_cast_error!(
198 "input has too many columns, expected {expected_len} but got {actual_len}"
199 ),
200 std::cmp::Ordering::Greater => bail_cast_error!(
201 "input has too few columns, expected {expected_len} but got {actual_len}"
202 ),
203 }
204 }
205
206 pub fn new_unchecked(
209 func_type: ExprType,
210 inputs: Vec<ExprImpl>,
211 return_type: DataType,
212 ) -> Self {
213 FunctionCall {
214 func_type,
215 return_type,
216 inputs,
217 }
218 }
219
220 pub fn new_binary_op_func(
221 mut func_types: Vec<ExprType>,
222 mut inputs: Vec<ExprImpl>,
223 ) -> RwResult<ExprImpl> {
224 let expr_type = func_types.remove(0);
225 match expr_type {
226 ExprType::Some | ExprType::All => {
227 let return_type = infer_some_all(func_types, &mut inputs)?;
228 Ok(FunctionCall::new_unchecked(expr_type, inputs, return_type).into())
229 }
230 ExprType::Not | ExprType::IsNotNull | ExprType::IsNull | ExprType::Neg => {
231 Ok(FunctionCall::new(
232 expr_type,
233 vec![Self::new_binary_op_func(func_types, inputs)?],
234 )?
235 .into())
236 }
237 _ => Ok(FunctionCall::new(expr_type, inputs)?.into()),
238 }
239 }
240
241 pub fn decompose(self) -> (ExprType, Vec<ExprImpl>, DataType) {
242 (self.func_type, self.inputs, self.return_type)
243 }
244
245 pub fn decompose_as_binary(self) -> (ExprType, ExprImpl, ExprImpl) {
246 assert_eq!(self.inputs.len(), 2);
247 let mut iter = self.inputs.into_iter();
248 let left = iter.next().unwrap();
249 let right = iter.next().unwrap();
250 (self.func_type, left, right)
251 }
252
253 pub fn decompose_as_unary(self) -> (ExprType, ExprImpl) {
254 assert_eq!(self.inputs.len(), 1);
255 let mut iter = self.inputs.into_iter();
256 let input = iter.next().unwrap();
257 (self.func_type, input)
258 }
259
260 pub fn func_type(&self) -> ExprType {
261 self.func_type
262 }
263
264 pub fn inputs(&self) -> &[ExprImpl] {
266 self.inputs.as_ref()
267 }
268
269 pub fn inputs_mut(&mut self) -> &mut [ExprImpl] {
270 self.inputs.as_mut()
271 }
272
273 pub(super) fn from_expr_proto(
274 function_call: &risingwave_pb::expr::FunctionCall,
275 func_type: ExprType,
276 return_type: DataType,
277 ) -> RwResult<Self> {
278 let inputs: Vec<_> = function_call
279 .get_children()
280 .iter()
281 .map(ExprImpl::from_expr_proto)
282 .try_collect()?;
283 Ok(Self {
284 func_type,
285 return_type,
286 inputs,
287 })
288 }
289
290 pub fn is_pure(&self) -> bool {
291 !is_impure_func_call(self)
292 }
293}
294
295impl Expr for FunctionCall {
296 fn return_type(&self) -> DataType {
297 self.return_type.clone()
298 }
299
300 fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
301 use risingwave_pb::expr::expr_node::*;
302 use risingwave_pb::expr::*;
303
304 let children = self
305 .inputs()
306 .iter()
307 .map(|input| input.try_to_expr_proto())
308 .try_collect()?;
309
310 Ok(ExprNode {
311 function_type: self.func_type().into(),
312 return_type: Some(self.return_type().to_protobuf()),
313 rex_node: Some(RexNode::FuncCall(FunctionCall { children })),
314 })
315 }
316}
317
/// Formatting wrapper for a [`FunctionCall`] whose operands are rendered through
/// `ExprDisplay` with access to `input_schema`.
pub struct FunctionCallDisplay<'a> {
    /// The call to render.
    pub function_call: &'a FunctionCall,
    /// Schema handed to every `ExprDisplay` operand — NOTE(review): presumably used to show
    /// input references by column name; confirm in `ExprDisplay`.
    pub input_schema: &'a Schema,
}
322
323impl std::fmt::Debug for FunctionCallDisplay<'_> {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 let that = self.function_call;
326 match &that.func_type {
327 ExprType::Cast => {
328 assert_eq!(that.inputs.len(), 1);
329 ExprDisplay {
330 expr: &that.inputs[0],
331 input_schema: self.input_schema,
332 }
333 .fmt(f)?;
334 write!(f, "::{:?}", that.return_type)
335 }
336 ExprType::Add => explain_verbose_binary_op(f, "+", &that.inputs, self.input_schema),
337 ExprType::Subtract => {
338 explain_verbose_binary_op(f, "-", &that.inputs, self.input_schema)
339 }
340 ExprType::Multiply => {
341 explain_verbose_binary_op(f, "*", &that.inputs, self.input_schema)
342 }
343 ExprType::Divide => explain_verbose_binary_op(f, "/", &that.inputs, self.input_schema),
344 ExprType::Modulus => explain_verbose_binary_op(f, "%", &that.inputs, self.input_schema),
345 ExprType::Equal => explain_verbose_binary_op(f, "=", &that.inputs, self.input_schema),
346 ExprType::NotEqual => {
347 explain_verbose_binary_op(f, "<>", &that.inputs, self.input_schema)
348 }
349 ExprType::LessThan => {
350 explain_verbose_binary_op(f, "<", &that.inputs, self.input_schema)
351 }
352 ExprType::LessThanOrEqual => {
353 explain_verbose_binary_op(f, "<=", &that.inputs, self.input_schema)
354 }
355 ExprType::GreaterThan => {
356 explain_verbose_binary_op(f, ">", &that.inputs, self.input_schema)
357 }
358 ExprType::GreaterThanOrEqual => {
359 explain_verbose_binary_op(f, ">=", &that.inputs, self.input_schema)
360 }
361 ExprType::And => explain_verbose_binary_op(f, "AND", &that.inputs, self.input_schema),
362 ExprType::Or => explain_verbose_binary_op(f, "OR", &that.inputs, self.input_schema),
363 ExprType::BitwiseShiftLeft => {
364 explain_verbose_binary_op(f, "<<", &that.inputs, self.input_schema)
365 }
366 ExprType::BitwiseShiftRight => {
367 explain_verbose_binary_op(f, ">>", &that.inputs, self.input_schema)
368 }
369 ExprType::BitwiseAnd => {
370 explain_verbose_binary_op(f, "&", &that.inputs, self.input_schema)
371 }
372 ExprType::BitwiseOr => {
373 explain_verbose_binary_op(f, "|", &that.inputs, self.input_schema)
374 }
375 ExprType::BitwiseXor => {
376 explain_verbose_binary_op(f, "#", &that.inputs, self.input_schema)
377 }
378 ExprType::ArrayContains => {
379 explain_verbose_binary_op(f, "@>", &that.inputs, self.input_schema)
380 }
381 ExprType::ArrayContained => {
382 explain_verbose_binary_op(f, "<@", &that.inputs, self.input_schema)
383 }
384 ExprType::ArrayOverlaps => {
385 explain_verbose_binary_op(f, "&&", &that.inputs, self.input_schema)
386 }
387 ExprType::Proctime => {
388 write!(f, "{:?}", that.func_type)
389 }
390 _ => {
391 let func_name = format!("{:?}", that.func_type);
392 let mut builder = f.debug_tuple(&func_name);
393 that.inputs.iter().for_each(|child| {
394 builder.field(&ExprDisplay {
395 expr: child,
396 input_schema: self.input_schema,
397 });
398 });
399 builder.finish()
400 }
401 }
402 }
403}
404
405fn explain_verbose_binary_op(
406 f: &mut std::fmt::Formatter<'_>,
407 op: &str,
408 inputs: &[ExprImpl],
409 input_schema: &Schema,
410) -> std::fmt::Result {
411 use std::fmt::Debug;
412
413 assert_eq!(inputs.len(), 2);
414
415 write!(f, "(")?;
416 ExprDisplay {
417 expr: &inputs[0],
418 input_schema,
419 }
420 .fmt(f)?;
421 write!(f, " {} ", op)?;
422 ExprDisplay {
423 expr: &inputs[1],
424 input_schema,
425 }
426 .fmt(f)?;
427 write!(f, ")")?;
428
429 Ok(())
430}
431
432pub fn is_row_function(expr: &ExprImpl) -> bool {
433 if let ExprImpl::FunctionCall(func) = expr
434 && func.func_type() == ExprType::Row
435 {
436 return true;
437 }
438 false
439}