risingwave_frontend/expr/
agg_call.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::types::DataType;
16use risingwave_expr::aggregate::AggType;
17
18use super::{Expr, ExprImpl, Literal, OrderBy, infer_type};
19use crate::error::Result;
20use crate::utils::Condition;
21
/// An aggregate function call appearing in a frontend expression.
///
/// `Clone`, `Eq`, `PartialEq` and `Hash` are derived, so two calls compare and
/// hash equal only when every field below matches.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct AggCall {
    /// Which aggregate is called: a builtin kind (`AggType::Builtin`), a
    /// user-defined aggregate (`AggType::UserDefined`), or a wrapped scalar
    /// (`AggType::WrapScalar`).
    pub agg_type: AggType,
    /// Output type of the call. `AggCall::new` derives it from `agg_type`;
    /// `AggCall::new_unchecked` stores the caller-supplied type as-is.
    pub return_type: DataType,
    /// Argument expressions passed to the aggregate.
    pub args: Vec<ExprImpl>,
    /// `DISTINCT` flag of the call; `AggCall::new_unchecked` sets it to `false`.
    pub distinct: bool,
    /// Ordering attached to the call (presumably an in-call `ORDER BY` — confirm);
    /// `AggCall::new_unchecked` uses `OrderBy::any()`.
    pub order_by: OrderBy,
    /// Filter condition attached to the call (presumably `FILTER (WHERE ...)` —
    /// confirm); `AggCall::new_unchecked` uses `Condition::true_cond()`.
    pub filter: Condition,
    /// Literal arguments kept separately from `args`. NOTE(review): presumably the
    /// direct arguments of ordered-set aggregates (e.g. the fraction in
    /// `percentile_cont`) — confirm. `AggCall::new_unchecked` leaves this empty.
    pub direct_args: Vec<Literal>,
}
32
33impl std::fmt::Debug for AggCall {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        if f.alternate() {
36            f.debug_struct("AggCall")
37                .field("agg_type", &self.agg_type)
38                .field("return_type", &self.return_type)
39                .field("args", &self.args)
40                .field("filter", &self.filter)
41                .field("distinct", &self.distinct)
42                .field("order_by", &self.order_by)
43                .field("direct_args", &self.direct_args)
44                .finish()
45        } else {
46            let mut builder = f.debug_tuple(&format!("{}", self.agg_type));
47            self.args.iter().for_each(|child| {
48                builder.field(child);
49            });
50            builder.finish()
51        }
52    }
53}
54
55impl AggCall {
56    /// Returns error if the function name matches with an existing function
57    /// but with illegal arguments.
58    pub fn new(
59        agg_type: AggType,
60        mut args: Vec<ExprImpl>,
61        distinct: bool,
62        order_by: OrderBy,
63        filter: Condition,
64        direct_args: Vec<Literal>,
65    ) -> Result<Self> {
66        let return_type = match &agg_type {
67            AggType::Builtin(kind) => infer_type((*kind).into(), &mut args)?,
68            AggType::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(),
69            AggType::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(),
70        };
71        Ok(AggCall {
72            agg_type,
73            return_type,
74            args,
75            distinct,
76            order_by,
77            filter,
78            direct_args,
79        })
80    }
81
82    /// Constructs an `AggCall` without type inference.
83    pub fn new_unchecked(
84        agg_type: AggType,
85        args: Vec<ExprImpl>,
86        return_type: DataType,
87    ) -> Result<Self> {
88        Ok(AggCall {
89            agg_type,
90            return_type,
91            args,
92            distinct: false,
93            order_by: OrderBy::any(),
94            filter: Condition::true_cond(),
95            direct_args: vec![],
96        })
97    }
98
99    pub fn agg_type(&self) -> AggType {
100        self.agg_type.clone()
101    }
102
103    /// Get a reference to the agg call's arguments.
104    pub fn args(&self) -> &[ExprImpl] {
105        self.args.as_ref()
106    }
107
108    pub fn args_mut(&mut self) -> &mut [ExprImpl] {
109        self.args.as_mut()
110    }
111
112    pub fn order_by(&self) -> &OrderBy {
113        &self.order_by
114    }
115
116    pub fn order_by_mut(&mut self) -> &mut OrderBy {
117        &mut self.order_by
118    }
119
120    pub fn filter(&self) -> &Condition {
121        &self.filter
122    }
123
124    pub fn filter_mut(&mut self) -> &mut Condition {
125        &mut self.filter
126    }
127}
128
129impl Expr for AggCall {
130    fn return_type(&self) -> DataType {
131        self.return_type.clone()
132    }
133
134    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
135        // This function is always called on the physical planning step, where
136        // `ExprImpl::AggCall` must have been rewritten to aggregate operators.
137
138        unreachable!(
139            "AggCall {:?} has not been rewritten to physical aggregate operators",
140            self
141        )
142    }
143}