1use itertools::Itertools;
16use risingwave_common::catalog::Schema;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::iter_util::ZipEqFast;
19
20use super::type_inference::cast;
21use super::{CastContext, CastError, Expr, ExprImpl, Literal, infer_some_all, infer_type};
22use crate::error::Result as RwResult;
23use crate::expr::{ExprDisplay, ExprType, ExprVisitor, ImpureAnalyzer, bail_cast_error};
24
/// A call expression: the function/operator `func_type` applied to `inputs`,
/// yielding a value of type `return_type`.
///
/// Equality and hashing are derived structurally over all three fields.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct FunctionCall {
    /// The function or operator being invoked.
    pub(super) func_type: ExprType,
    /// Result type of the call. `FunctionCall::new` computes it with `infer_type`;
    /// `new_unchecked` and `from_expr_proto` take it from the caller as-is.
    pub(super) return_type: DataType,
    /// Argument expressions, in call order.
    pub(super) inputs: Vec<ExprImpl>,
}
31
32fn debug_binary_op(
33 f: &mut std::fmt::Formatter<'_>,
34 op: &str,
35 inputs: &[ExprImpl],
36) -> std::fmt::Result {
37 use std::fmt::Debug;
38
39 assert_eq!(inputs.len(), 2);
40
41 write!(f, "(")?;
42 inputs[0].fmt(f)?;
43 write!(f, " {} ", op)?;
44 inputs[1].fmt(f)?;
45 write!(f, ")")?;
46
47 Ok(())
48}
49
50impl std::fmt::Debug for FunctionCall {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 if f.alternate() {
53 f.debug_struct("FunctionCall")
54 .field("func_type", &self.func_type)
55 .field("return_type", &self.return_type)
56 .field("inputs", &self.inputs)
57 .finish()
58 } else {
59 match &self.func_type {
60 ExprType::Cast => {
61 assert_eq!(self.inputs.len(), 1);
62 self.inputs[0].fmt(f)?;
63 write!(f, "::{:?}", self.return_type)
64 }
65 ExprType::Add => debug_binary_op(f, "+", &self.inputs),
66 ExprType::Subtract => debug_binary_op(f, "-", &self.inputs),
67 ExprType::Multiply => debug_binary_op(f, "*", &self.inputs),
68 ExprType::Divide => debug_binary_op(f, "/", &self.inputs),
69 ExprType::Modulus => debug_binary_op(f, "%", &self.inputs),
70 ExprType::Equal => debug_binary_op(f, "=", &self.inputs),
71 ExprType::NotEqual => debug_binary_op(f, "<>", &self.inputs),
72 ExprType::LessThan => debug_binary_op(f, "<", &self.inputs),
73 ExprType::LessThanOrEqual => debug_binary_op(f, "<=", &self.inputs),
74 ExprType::GreaterThan => debug_binary_op(f, ">", &self.inputs),
75 ExprType::GreaterThanOrEqual => debug_binary_op(f, ">=", &self.inputs),
76 ExprType::And => debug_binary_op(f, "AND", &self.inputs),
77 ExprType::Or => debug_binary_op(f, "OR", &self.inputs),
78 ExprType::BitwiseShiftLeft => debug_binary_op(f, "<<", &self.inputs),
79 ExprType::BitwiseShiftRight => debug_binary_op(f, ">>", &self.inputs),
80 ExprType::BitwiseAnd => debug_binary_op(f, "&", &self.inputs),
81 ExprType::BitwiseOr => debug_binary_op(f, "|", &self.inputs),
82 ExprType::BitwiseXor => debug_binary_op(f, "#", &self.inputs),
83 ExprType::ArrayContains => debug_binary_op(f, "@>", &self.inputs),
84 ExprType::ArrayContained => debug_binary_op(f, "<@", &self.inputs),
85 _ => {
86 let func_name = format!("{:?}", self.func_type);
87 let mut builder = f.debug_tuple(&func_name);
88 self.inputs.iter().for_each(|child| {
89 builder.field(child);
90 });
91 builder.finish()
92 }
93 }
94 }
95 }
96}
97
98impl FunctionCall {
99 pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> RwResult<Self> {
105 let return_type = infer_type(func_type.into(), &mut inputs)?;
106 Ok(Self::new_unchecked(func_type, inputs, return_type))
107 }
108
109 pub fn cast_mut(
112 child: &mut ExprImpl,
113 target: &DataType,
114 allows: CastContext,
115 ) -> Result<(), CastError> {
116 if let ExprImpl::Parameter(expr) = child
117 && !expr.has_infer()
118 {
119 expr.cast_infer_type(target);
121 return Ok(());
122 }
123 if let ExprImpl::FunctionCall(func) = child
124 && func.func_type == ExprType::Row
125 {
126 return Self::cast_row_expr(func, target, allows);
130 }
131 if child.is_untyped() {
132 let literal = child.as_literal().unwrap();
134 let datum = literal
135 .get_data()
136 .as_ref()
137 .map(|scalar| ScalarImpl::from_text(scalar.as_utf8(), target))
138 .transpose();
139 if let Ok(datum) = datum {
140 *child = Literal::new(datum, target.clone()).into();
141 return Ok(());
142 }
143 }
146
147 let source = child.return_type();
148 if &source == target {
149 return Ok(());
150 }
151
152 if child.is_untyped() {
153 } else {
156 cast(&source, target, allows)?;
157 }
158
159 let owned = std::mem::replace(child, ExprImpl::literal_bool(false));
161 *child = Self::new_unchecked(ExprType::Cast, vec![owned], target.clone()).into();
162 Ok(())
163 }
164
165 fn cast_row_expr(
169 func: &mut FunctionCall,
170 target_type: &DataType,
171 allows: CastContext,
172 ) -> Result<(), CastError> {
173 let DataType::Struct(t) = &target_type else {
175 bail_cast_error!(
176 "cannot cast type \"{}\" to \"{}\"",
177 func.return_type(), target_type,
179 );
180 };
181
182 let expected_len = t.len();
183 let actual_len = func.inputs.len();
184
185 match expected_len.cmp(&actual_len) {
186 std::cmp::Ordering::Equal => {
187 func.inputs
190 .iter_mut()
191 .zip_eq_fast(t.types())
192 .try_for_each(|(e, t)| Self::cast_mut(e, t, allows))?;
193 func.return_type = target_type.clone();
194 Ok(())
195 }
196 std::cmp::Ordering::Less => bail_cast_error!(
197 "input has too many columns, expected {expected_len} but got {actual_len}"
198 ),
199 std::cmp::Ordering::Greater => bail_cast_error!(
200 "input has too few columns, expected {expected_len} but got {actual_len}"
201 ),
202 }
203 }
204
205 pub fn new_unchecked(
208 func_type: ExprType,
209 inputs: Vec<ExprImpl>,
210 return_type: DataType,
211 ) -> Self {
212 FunctionCall {
213 func_type,
214 return_type,
215 inputs,
216 }
217 }
218
219 pub fn new_binary_op_func(
220 mut func_types: Vec<ExprType>,
221 mut inputs: Vec<ExprImpl>,
222 ) -> RwResult<ExprImpl> {
223 let expr_type = func_types.remove(0);
224 match expr_type {
225 ExprType::Some | ExprType::All => {
226 let return_type = infer_some_all(func_types, &mut inputs)?;
227 Ok(FunctionCall::new_unchecked(expr_type, inputs, return_type).into())
228 }
229 ExprType::Not | ExprType::IsNotNull | ExprType::IsNull | ExprType::Neg => {
230 Ok(FunctionCall::new(
231 expr_type,
232 vec![Self::new_binary_op_func(func_types, inputs)?],
233 )?
234 .into())
235 }
236 _ => Ok(FunctionCall::new(expr_type, inputs)?.into()),
237 }
238 }
239
240 pub fn decompose(self) -> (ExprType, Vec<ExprImpl>, DataType) {
241 (self.func_type, self.inputs, self.return_type)
242 }
243
244 pub fn decompose_as_binary(self) -> (ExprType, ExprImpl, ExprImpl) {
245 assert_eq!(self.inputs.len(), 2);
246 let mut iter = self.inputs.into_iter();
247 let left = iter.next().unwrap();
248 let right = iter.next().unwrap();
249 (self.func_type, left, right)
250 }
251
252 pub fn decompose_as_unary(self) -> (ExprType, ExprImpl) {
253 assert_eq!(self.inputs.len(), 1);
254 let mut iter = self.inputs.into_iter();
255 let input = iter.next().unwrap();
256 (self.func_type, input)
257 }
258
259 pub fn func_type(&self) -> ExprType {
260 self.func_type
261 }
262
263 pub fn inputs(&self) -> &[ExprImpl] {
265 self.inputs.as_ref()
266 }
267
268 pub fn inputs_mut(&mut self) -> &mut [ExprImpl] {
269 self.inputs.as_mut()
270 }
271
272 pub(super) fn from_expr_proto(
273 function_call: &risingwave_pb::expr::FunctionCall,
274 func_type: ExprType,
275 return_type: DataType,
276 ) -> RwResult<Self> {
277 let inputs: Vec<_> = function_call
278 .get_children()
279 .iter()
280 .map(ExprImpl::from_expr_proto)
281 .try_collect()?;
282 Ok(Self {
283 func_type,
284 return_type,
285 inputs,
286 })
287 }
288
289 pub fn is_pure(&self) -> bool {
290 let mut a = ImpureAnalyzer { impure: false };
291 a.visit_function_call(self);
292 !a.impure
293 }
294}
295
296impl Expr for FunctionCall {
297 fn return_type(&self) -> DataType {
298 self.return_type.clone()
299 }
300
301 fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
302 use risingwave_pb::expr::expr_node::*;
303 use risingwave_pb::expr::*;
304
305 let children = self
306 .inputs()
307 .iter()
308 .map(|input| input.try_to_expr_proto())
309 .try_collect()?;
310
311 Ok(ExprNode {
312 function_type: self.func_type().into(),
313 return_type: Some(self.return_type().to_protobuf()),
314 rex_node: Some(RexNode::FuncCall(FunctionCall { children })),
315 })
316 }
317}
318
/// Borrowed view that formats a [`FunctionCall`] via `Debug`, rendering each
/// input through `ExprDisplay` together with `input_schema` (presumably so
/// input references can be shown by column — confirm in `ExprDisplay`).
pub struct FunctionCallDisplay<'a> {
    /// The call to format.
    pub function_call: &'a FunctionCall,
    /// Schema passed to `ExprDisplay` for every input of the call.
    pub input_schema: &'a Schema,
}
323
324impl std::fmt::Debug for FunctionCallDisplay<'_> {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 let that = self.function_call;
327 match &that.func_type {
328 ExprType::Cast => {
329 assert_eq!(that.inputs.len(), 1);
330 ExprDisplay {
331 expr: &that.inputs[0],
332 input_schema: self.input_schema,
333 }
334 .fmt(f)?;
335 write!(f, "::{:?}", that.return_type)
336 }
337 ExprType::Add => explain_verbose_binary_op(f, "+", &that.inputs, self.input_schema),
338 ExprType::Subtract => {
339 explain_verbose_binary_op(f, "-", &that.inputs, self.input_schema)
340 }
341 ExprType::Multiply => {
342 explain_verbose_binary_op(f, "*", &that.inputs, self.input_schema)
343 }
344 ExprType::Divide => explain_verbose_binary_op(f, "/", &that.inputs, self.input_schema),
345 ExprType::Modulus => explain_verbose_binary_op(f, "%", &that.inputs, self.input_schema),
346 ExprType::Equal => explain_verbose_binary_op(f, "=", &that.inputs, self.input_schema),
347 ExprType::NotEqual => {
348 explain_verbose_binary_op(f, "<>", &that.inputs, self.input_schema)
349 }
350 ExprType::LessThan => {
351 explain_verbose_binary_op(f, "<", &that.inputs, self.input_schema)
352 }
353 ExprType::LessThanOrEqual => {
354 explain_verbose_binary_op(f, "<=", &that.inputs, self.input_schema)
355 }
356 ExprType::GreaterThan => {
357 explain_verbose_binary_op(f, ">", &that.inputs, self.input_schema)
358 }
359 ExprType::GreaterThanOrEqual => {
360 explain_verbose_binary_op(f, ">=", &that.inputs, self.input_schema)
361 }
362 ExprType::And => explain_verbose_binary_op(f, "AND", &that.inputs, self.input_schema),
363 ExprType::Or => explain_verbose_binary_op(f, "OR", &that.inputs, self.input_schema),
364 ExprType::BitwiseShiftLeft => {
365 explain_verbose_binary_op(f, "<<", &that.inputs, self.input_schema)
366 }
367 ExprType::BitwiseShiftRight => {
368 explain_verbose_binary_op(f, ">>", &that.inputs, self.input_schema)
369 }
370 ExprType::BitwiseAnd => {
371 explain_verbose_binary_op(f, "&", &that.inputs, self.input_schema)
372 }
373 ExprType::BitwiseOr => {
374 explain_verbose_binary_op(f, "|", &that.inputs, self.input_schema)
375 }
376 ExprType::BitwiseXor => {
377 explain_verbose_binary_op(f, "#", &that.inputs, self.input_schema)
378 }
379 ExprType::ArrayContains => {
380 explain_verbose_binary_op(f, "@>", &that.inputs, self.input_schema)
381 }
382 ExprType::ArrayContained => {
383 explain_verbose_binary_op(f, "<@", &that.inputs, self.input_schema)
384 }
385 ExprType::Proctime => {
386 write!(f, "{:?}", that.func_type)
387 }
388 _ => {
389 let func_name = format!("{:?}", that.func_type);
390 let mut builder = f.debug_tuple(&func_name);
391 that.inputs.iter().for_each(|child| {
392 builder.field(&ExprDisplay {
393 expr: child,
394 input_schema: self.input_schema,
395 });
396 });
397 builder.finish()
398 }
399 }
400 }
401}
402
403fn explain_verbose_binary_op(
404 f: &mut std::fmt::Formatter<'_>,
405 op: &str,
406 inputs: &[ExprImpl],
407 input_schema: &Schema,
408) -> std::fmt::Result {
409 use std::fmt::Debug;
410
411 assert_eq!(inputs.len(), 2);
412
413 write!(f, "(")?;
414 ExprDisplay {
415 expr: &inputs[0],
416 input_schema,
417 }
418 .fmt(f)?;
419 write!(f, " {} ", op)?;
420 ExprDisplay {
421 expr: &inputs[1],
422 input_schema,
423 }
424 .fmt(f)?;
425 write!(f, ")")?;
426
427 Ok(())
428}
429
430pub fn is_row_function(expr: &ExprImpl) -> bool {
431 if let ExprImpl::FunctionCall(func) = expr
432 && func.func_type() == ExprType::Row
433 {
434 return true;
435 }
436 false
437}