// risingwave_frontend/expr/function_call_with_lambda.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use risingwave_common::types::DataType;
16
17use super::{ExprImpl, FunctionCall};
18use crate::expr::{Expr, ExprType};
19
20/// Similar to [`FunctionCall`], with an extra lambda function argument.
21#[derive(Clone, PartialEq, Eq, Hash)]
22pub struct FunctionCallWithLambda {
23    base: FunctionCall,
24    lambda_arg: ExprImpl,
25}
26
27impl std::fmt::Debug for FunctionCallWithLambda {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        if f.alternate() {
30            f.debug_struct("FunctionCallWithLambda")
31                .field("func_type", &self.base.func_type)
32                .field("return_type", &self.base.return_type)
33                .field("inputs", &self.base.inputs)
34                .field("lambda_arg", &self.lambda_arg)
35                .finish()
36        } else {
37            let func_name = format!("{:?}", self.base.func_type);
38            let mut builder = f.debug_tuple(&func_name);
39            for input in &self.base.inputs {
40                builder.field(input);
41            }
42            builder.field(&self.lambda_arg);
43            builder.finish()
44        }
45    }
46}
47
48impl FunctionCallWithLambda {
49    pub fn new_unchecked(
50        func_type: ExprType,
51        inputs: Vec<ExprImpl>,
52        lambda_arg: ExprImpl,
53        return_type: DataType,
54    ) -> Self {
55        assert!([ExprType::ArrayTransform].contains(&func_type));
56        Self {
57            base: FunctionCall::new_unchecked(func_type, inputs, return_type),
58            lambda_arg,
59        }
60    }
61
62    pub fn inputs(&self) -> &[ExprImpl] {
63        self.base.inputs()
64    }
65
66    pub fn func_type(&self) -> ExprType {
67        self.base.func_type()
68    }
69
70    pub fn return_type(&self) -> DataType {
71        self.base.return_type()
72    }
73
74    pub fn base(&self) -> &FunctionCall {
75        &self.base
76    }
77
78    pub fn base_mut(&mut self) -> &mut FunctionCall {
79        &mut self.base
80    }
81
82    pub fn inputs_with_lambda_arg(&self) -> impl Iterator<Item = &'_ ExprImpl> {
83        self.inputs().iter().chain([&self.lambda_arg])
84    }
85
86    pub fn to_full_function_call(&self) -> FunctionCall {
87        let full_inputs = self.inputs_with_lambda_arg().cloned();
88        FunctionCall::new_unchecked(self.func_type(), full_inputs.collect(), self.return_type())
89    }
90
91    pub fn into_parts(self) -> (ExprType, Vec<ExprImpl>, ExprImpl, DataType) {
92        let Self { base, lambda_arg } = self;
93        let (func_type, inputs, return_type) = base.decompose();
94        (func_type, inputs, lambda_arg, return_type)
95    }
96}
97
98impl Expr for FunctionCallWithLambda {
99    fn return_type(&self) -> DataType {
100        self.base.return_type()
101    }
102
103    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
104        use risingwave_pb::expr::expr_node::*;
105        use risingwave_pb::expr::*;
106        ExprNode {
107            function_type: self.func_type().into(),
108            return_type: Some(self.return_type().to_protobuf()),
109            rex_node: Some(RexNode::FuncCall(FunctionCall {
110                children: self
111                    .inputs_with_lambda_arg()
112                    .map(Expr::to_expr_proto)
113                    .collect(),
114            })),
115        }
116    }
117}