risingwave_frontend/expr/
function_call.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::catalog::Schema;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::iter_util::ZipEqFast;
19
20use super::type_inference::cast;
21use super::{CastContext, CastError, Expr, ExprImpl, Literal, infer_some_all, infer_type};
22use crate::error::Result as RwResult;
23use crate::expr::{ExprDisplay, ExprType, bail_cast_error, is_impure_func_call};
24
/// An application of a function or operator, identified by an [`ExprType`], to a list of
/// argument expressions.
///
/// Build one with [`FunctionCall::new`] to have the return type inferred from the inputs, or
/// with [`FunctionCall::new_unchecked`] to supply the return type directly.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct FunctionCall {
    /// The function or operator being applied.
    pub(super) func_type: ExprType,
    /// The type of the value produced by this call.
    pub(super) return_type: DataType,
    /// The argument expressions, in call order.
    pub(super) inputs: Vec<ExprImpl>,
}
31
32fn debug_binary_op(
33    f: &mut std::fmt::Formatter<'_>,
34    op: &str,
35    inputs: &[ExprImpl],
36) -> std::fmt::Result {
37    use std::fmt::Debug;
38
39    assert_eq!(inputs.len(), 2);
40
41    write!(f, "(")?;
42    inputs[0].fmt(f)?;
43    write!(f, " {} ", op)?;
44    inputs[1].fmt(f)?;
45    write!(f, ")")?;
46
47    Ok(())
48}
49
50impl std::fmt::Debug for FunctionCall {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        if f.alternate() {
53            f.debug_struct("FunctionCall")
54                .field("func_type", &self.func_type)
55                .field("return_type", &self.return_type)
56                .field("inputs", &self.inputs)
57                .finish()
58        } else {
59            match &self.func_type {
60                ExprType::Cast => {
61                    assert_eq!(self.inputs.len(), 1);
62                    self.inputs[0].fmt(f)?;
63                    write!(f, "::{:?}", self.return_type)
64                }
65                ExprType::Add => debug_binary_op(f, "+", &self.inputs),
66                ExprType::Subtract => debug_binary_op(f, "-", &self.inputs),
67                ExprType::Multiply => debug_binary_op(f, "*", &self.inputs),
68                ExprType::Divide => debug_binary_op(f, "/", &self.inputs),
69                ExprType::Modulus => debug_binary_op(f, "%", &self.inputs),
70                ExprType::Equal => debug_binary_op(f, "=", &self.inputs),
71                ExprType::NotEqual => debug_binary_op(f, "<>", &self.inputs),
72                ExprType::LessThan => debug_binary_op(f, "<", &self.inputs),
73                ExprType::LessThanOrEqual => debug_binary_op(f, "<=", &self.inputs),
74                ExprType::GreaterThan => debug_binary_op(f, ">", &self.inputs),
75                ExprType::GreaterThanOrEqual => debug_binary_op(f, ">=", &self.inputs),
76                ExprType::And => debug_binary_op(f, "AND", &self.inputs),
77                ExprType::Or => debug_binary_op(f, "OR", &self.inputs),
78                ExprType::BitwiseShiftLeft => debug_binary_op(f, "<<", &self.inputs),
79                ExprType::BitwiseShiftRight => debug_binary_op(f, ">>", &self.inputs),
80                ExprType::BitwiseAnd => debug_binary_op(f, "&", &self.inputs),
81                ExprType::BitwiseOr => debug_binary_op(f, "|", &self.inputs),
82                ExprType::BitwiseXor => debug_binary_op(f, "#", &self.inputs),
83                ExprType::ArrayContains => debug_binary_op(f, "@>", &self.inputs),
84                ExprType::ArrayContained => debug_binary_op(f, "<@", &self.inputs),
85                ExprType::ArrayOverlaps => debug_binary_op(f, "&&", &self.inputs),
86                _ => {
87                    let func_name = format!("{:?}", self.func_type);
88                    let mut builder = f.debug_tuple(&func_name);
89                    self.inputs.iter().for_each(|child| {
90                        builder.field(child);
91                    });
92                    builder.finish()
93                }
94            }
95        }
96    }
97}
98
99impl FunctionCall {
100    /// Create a `FunctionCall` expr with the return type inferred from `func_type` and types of
101    /// `inputs`.
102    // The functions listed here are all variadic.  Type signatures of functions that take a fixed
103    // number of arguments are checked
104    // [elsewhere](crate::expr::type_inference::build_type_derive_map).
105    pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> RwResult<Self> {
106        let return_type = infer_type(func_type.into(), &mut inputs)?;
107        Ok(Self::new_unchecked(func_type, inputs, return_type))
108    }
109
110    /// Create a cast expr over `child` to `target` type in `allows` context.
111    /// The input `child` remains unchanged when this returns an error.
112    pub fn cast_mut(
113        child: &mut ExprImpl,
114        target: &DataType,
115        allows: CastContext,
116    ) -> Result<(), CastError> {
117        if let ExprImpl::Parameter(expr) = child
118            && !expr.has_infer()
119        {
120            // Always Ok below. Safe to mutate `expr` (from `child`).
121            expr.cast_infer_type(target);
122            return Ok(());
123        }
124        if let ExprImpl::FunctionCall(func) = child
125            && func.func_type == ExprType::Row
126        {
127            // Row function will have empty fields in Datatype::Struct at this point. Therefore,
128            // we will need to take some special care to generate the cast types. For normal struct
129            // types, they will be handled in `cast_ok`.
130            return Self::cast_row_expr(func, target, allows);
131        }
132        if child.is_untyped() {
133            // `is_unknown` makes sure `as_literal` and `as_utf8` will never panic.
134            let literal = child.as_literal().unwrap();
135            let datum = literal
136                .get_data()
137                .as_ref()
138                .map(|scalar| ScalarImpl::from_text(scalar.as_utf8(), target))
139                .transpose();
140            if let Ok(datum) = datum {
141                *child = Literal::new(datum, target.clone()).into();
142                return Ok(());
143            }
144            // else when eager parsing fails, just proceed as normal.
145            // Some callers are not ready to handle `'a'::int` error here.
146        }
147
148        let source = child.return_type();
149        if &source == target {
150            return Ok(());
151        }
152
153        if child.is_untyped() {
154            // Casting from unknown is allowed in all context. And PostgreSQL actually does the parsing
155            // in frontend.
156        } else {
157            cast(&source, target, allows)?;
158        }
159
160        // Always Ok below. Safe to mutate `child`.
161        let owned = std::mem::replace(child, ExprImpl::literal_bool(false));
162        *child = Self::new_unchecked(ExprType::Cast, vec![owned], target.clone()).into();
163        Ok(())
164    }
165
166    /// Cast a `ROW` expression to the target type. We intentionally disallow casting arbitrary
167    /// expressions, like `ROW(1)::STRUCT<i INTEGER>` to `STRUCT<VARCHAR>`, although an integer
168    /// is castable to VARCHAR. It's to simply the casting rules.
169    fn cast_row_expr(
170        func: &mut FunctionCall,
171        target_type: &DataType,
172        allows: CastContext,
173    ) -> Result<(), CastError> {
174        // Can only cast to a struct type.
175        let DataType::Struct(t) = &target_type else {
176            bail_cast_error!(
177                "cannot cast type \"{}\" to \"{}\"",
178                func.return_type(), // typically "record"
179                target_type,
180            );
181        };
182
183        let expected_len = t.len();
184        let actual_len = func.inputs.len();
185
186        match expected_len.cmp(&actual_len) {
187            std::cmp::Ordering::Equal => {
188                // FIXME: `func` shall not be in a partially mutated state when one of its fields
189                // fails to cast.
190                func.inputs
191                    .iter_mut()
192                    .zip_eq_fast(t.types())
193                    .try_for_each(|(e, t)| Self::cast_mut(e, t, allows))?;
194                func.return_type = target_type.clone();
195                Ok(())
196            }
197            std::cmp::Ordering::Less => bail_cast_error!(
198                "input has too many columns, expected {expected_len} but got {actual_len}"
199            ),
200            std::cmp::Ordering::Greater => bail_cast_error!(
201                "input has too few columns, expected {expected_len} but got {actual_len}"
202            ),
203        }
204    }
205
206    /// Construct a `FunctionCall` expr directly with the provided `return_type`, bypassing type
207    /// inference. Use with caution.
208    pub fn new_unchecked(
209        func_type: ExprType,
210        inputs: Vec<ExprImpl>,
211        return_type: DataType,
212    ) -> Self {
213        FunctionCall {
214            func_type,
215            return_type,
216            inputs,
217        }
218    }
219
220    pub fn new_binary_op_func(
221        mut func_types: Vec<ExprType>,
222        mut inputs: Vec<ExprImpl>,
223    ) -> RwResult<ExprImpl> {
224        let expr_type = func_types.remove(0);
225        match expr_type {
226            ExprType::Some | ExprType::All => {
227                let return_type = infer_some_all(func_types, &mut inputs)?;
228                Ok(FunctionCall::new_unchecked(expr_type, inputs, return_type).into())
229            }
230            ExprType::Not | ExprType::IsNotNull | ExprType::IsNull | ExprType::Neg => {
231                Ok(FunctionCall::new(
232                    expr_type,
233                    vec![Self::new_binary_op_func(func_types, inputs)?],
234                )?
235                .into())
236            }
237            _ => Ok(FunctionCall::new(expr_type, inputs)?.into()),
238        }
239    }
240
241    pub fn decompose(self) -> (ExprType, Vec<ExprImpl>, DataType) {
242        (self.func_type, self.inputs, self.return_type)
243    }
244
245    pub fn decompose_as_binary(self) -> (ExprType, ExprImpl, ExprImpl) {
246        assert_eq!(self.inputs.len(), 2);
247        let mut iter = self.inputs.into_iter();
248        let left = iter.next().unwrap();
249        let right = iter.next().unwrap();
250        (self.func_type, left, right)
251    }
252
253    pub fn decompose_as_unary(self) -> (ExprType, ExprImpl) {
254        assert_eq!(self.inputs.len(), 1);
255        let mut iter = self.inputs.into_iter();
256        let input = iter.next().unwrap();
257        (self.func_type, input)
258    }
259
260    pub fn func_type(&self) -> ExprType {
261        self.func_type
262    }
263
264    /// Get a reference to the function call's inputs.
265    pub fn inputs(&self) -> &[ExprImpl] {
266        self.inputs.as_ref()
267    }
268
269    pub fn inputs_mut(&mut self) -> &mut [ExprImpl] {
270        self.inputs.as_mut()
271    }
272
273    pub(super) fn from_expr_proto(
274        function_call: &risingwave_pb::expr::FunctionCall,
275        func_type: ExprType,
276        return_type: DataType,
277    ) -> RwResult<Self> {
278        let inputs: Vec<_> = function_call
279            .get_children()
280            .iter()
281            .map(ExprImpl::from_expr_proto)
282            .try_collect()?;
283        Ok(Self {
284            func_type,
285            return_type,
286            inputs,
287        })
288    }
289
290    pub fn is_pure(&self) -> bool {
291        !is_impure_func_call(self)
292    }
293}
294
295impl Expr for FunctionCall {
296    fn return_type(&self) -> DataType {
297        self.return_type.clone()
298    }
299
300    fn try_to_expr_proto(&self) -> Result<risingwave_pb::expr::ExprNode, String> {
301        use risingwave_pb::expr::expr_node::*;
302        use risingwave_pb::expr::*;
303
304        let children = self
305            .inputs()
306            .iter()
307            .map(|input| input.try_to_expr_proto())
308            .try_collect()?;
309
310        Ok(ExprNode {
311            function_type: self.func_type().into(),
312            return_type: Some(self.return_type().to_protobuf()),
313            rex_node: Some(RexNode::FuncCall(FunctionCall { children })),
314        })
315    }
316}
317
/// Pairs a [`FunctionCall`] with an input schema for verbose `Debug` rendering.
///
/// Each input expression is rendered through `ExprDisplay` together with `input_schema`.
/// Presumably the schema is used to resolve input references to column names (TODO confirm
/// in `ExprDisplay`).
pub struct FunctionCallDisplay<'a> {
    /// The call to render.
    pub function_call: &'a FunctionCall,
    /// Schema handed to `ExprDisplay` for every input expression of the call.
    pub input_schema: &'a Schema,
}
322
323impl std::fmt::Debug for FunctionCallDisplay<'_> {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        let that = self.function_call;
326        match &that.func_type {
327            ExprType::Cast => {
328                assert_eq!(that.inputs.len(), 1);
329                ExprDisplay {
330                    expr: &that.inputs[0],
331                    input_schema: self.input_schema,
332                }
333                .fmt(f)?;
334                write!(f, "::{:?}", that.return_type)
335            }
336            ExprType::Add => explain_verbose_binary_op(f, "+", &that.inputs, self.input_schema),
337            ExprType::Subtract => {
338                explain_verbose_binary_op(f, "-", &that.inputs, self.input_schema)
339            }
340            ExprType::Multiply => {
341                explain_verbose_binary_op(f, "*", &that.inputs, self.input_schema)
342            }
343            ExprType::Divide => explain_verbose_binary_op(f, "/", &that.inputs, self.input_schema),
344            ExprType::Modulus => explain_verbose_binary_op(f, "%", &that.inputs, self.input_schema),
345            ExprType::Equal => explain_verbose_binary_op(f, "=", &that.inputs, self.input_schema),
346            ExprType::NotEqual => {
347                explain_verbose_binary_op(f, "<>", &that.inputs, self.input_schema)
348            }
349            ExprType::LessThan => {
350                explain_verbose_binary_op(f, "<", &that.inputs, self.input_schema)
351            }
352            ExprType::LessThanOrEqual => {
353                explain_verbose_binary_op(f, "<=", &that.inputs, self.input_schema)
354            }
355            ExprType::GreaterThan => {
356                explain_verbose_binary_op(f, ">", &that.inputs, self.input_schema)
357            }
358            ExprType::GreaterThanOrEqual => {
359                explain_verbose_binary_op(f, ">=", &that.inputs, self.input_schema)
360            }
361            ExprType::And => explain_verbose_binary_op(f, "AND", &that.inputs, self.input_schema),
362            ExprType::Or => explain_verbose_binary_op(f, "OR", &that.inputs, self.input_schema),
363            ExprType::BitwiseShiftLeft => {
364                explain_verbose_binary_op(f, "<<", &that.inputs, self.input_schema)
365            }
366            ExprType::BitwiseShiftRight => {
367                explain_verbose_binary_op(f, ">>", &that.inputs, self.input_schema)
368            }
369            ExprType::BitwiseAnd => {
370                explain_verbose_binary_op(f, "&", &that.inputs, self.input_schema)
371            }
372            ExprType::BitwiseOr => {
373                explain_verbose_binary_op(f, "|", &that.inputs, self.input_schema)
374            }
375            ExprType::BitwiseXor => {
376                explain_verbose_binary_op(f, "#", &that.inputs, self.input_schema)
377            }
378            ExprType::ArrayContains => {
379                explain_verbose_binary_op(f, "@>", &that.inputs, self.input_schema)
380            }
381            ExprType::ArrayContained => {
382                explain_verbose_binary_op(f, "<@", &that.inputs, self.input_schema)
383            }
384            ExprType::ArrayOverlaps => {
385                explain_verbose_binary_op(f, "&&", &that.inputs, self.input_schema)
386            }
387            ExprType::Proctime => {
388                write!(f, "{:?}", that.func_type)
389            }
390            _ => {
391                let func_name = format!("{:?}", that.func_type);
392                let mut builder = f.debug_tuple(&func_name);
393                that.inputs.iter().for_each(|child| {
394                    builder.field(&ExprDisplay {
395                        expr: child,
396                        input_schema: self.input_schema,
397                    });
398                });
399                builder.finish()
400            }
401        }
402    }
403}
404
405fn explain_verbose_binary_op(
406    f: &mut std::fmt::Formatter<'_>,
407    op: &str,
408    inputs: &[ExprImpl],
409    input_schema: &Schema,
410) -> std::fmt::Result {
411    use std::fmt::Debug;
412
413    assert_eq!(inputs.len(), 2);
414
415    write!(f, "(")?;
416    ExprDisplay {
417        expr: &inputs[0],
418        input_schema,
419    }
420    .fmt(f)?;
421    write!(f, " {} ", op)?;
422    ExprDisplay {
423        expr: &inputs[1],
424        input_schema,
425    }
426    .fmt(f)?;
427    write!(f, ")")?;
428
429    Ok(())
430}
431
432pub fn is_row_function(expr: &ExprImpl) -> bool {
433    if let ExprImpl::FunctionCall(func) = expr
434        && func.func_type() == ExprType::Row
435    {
436        return true;
437    }
438    false
439}