risingwave_frontend/binder/expr/function/aggregate.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::types::{DataType, ScalarImpl};
17use risingwave_common::{bail, bail_not_implemented};
18use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
19use risingwave_sqlparser::ast::{self, FunctionArgExpr};
20
21use crate::Binder;
22use crate::binder::Clause;
23use crate::error::{ErrorCode, Result};
24use crate::expr::{AggCall, ExprImpl, Literal, OrderBy};
25use crate::utils::Condition;
26
27impl Binder {
28    fn ensure_aggregate_allowed(&self) -> Result<()> {
29        if let Some(clause) = self.context.clause {
30            match clause {
31                Clause::Where
32                | Clause::Values
33                | Clause::From
34                | Clause::GeneratedColumn
35                | Clause::Insert
36                | Clause::JoinOn => {
37                    return Err(ErrorCode::InvalidInputSyntax(format!(
38                        "aggregate functions are not allowed in {}",
39                        clause
40                    ))
41                    .into());
42                }
43                Clause::Having | Clause::Filter | Clause::GroupBy => {}
44            }
45        }
46        Ok(())
47    }
48
49    pub(super) fn bind_aggregate_function(
50        &mut self,
51        agg_type: AggType,
52        distinct: bool,
53        args: Vec<ExprImpl>,
54        order_by: Vec<ast::OrderByExpr>,
55        within_group: Option<Box<ast::OrderByExpr>>,
56        filter: Option<Box<ast::Expr>>,
57    ) -> Result<ExprImpl> {
58        self.ensure_aggregate_allowed()?;
59
60        let (direct_args, args, order_by) = if matches!(agg_type, agg_types::ordered_set!()) {
61            self.bind_ordered_set_agg(&agg_type, distinct, args, order_by, within_group)?
62        } else {
63            self.bind_normal_agg(&agg_type, distinct, args, order_by, within_group)?
64        };
65
66        let filter = match filter {
67            Some(filter) => {
68                let mut clause = Some(Clause::Filter);
69                std::mem::swap(&mut self.context.clause, &mut clause);
70                let expr = self
71                    .bind_expr_inner(*filter)
72                    .and_then(|expr| expr.enforce_bool_clause("FILTER"))?;
73                self.context.clause = clause;
74                if expr.has_subquery() {
75                    bail_not_implemented!("subquery in filter clause");
76                }
77                if expr.has_agg_call() {
78                    bail_not_implemented!("aggregation function in filter clause");
79                }
80                if expr.has_table_function() {
81                    bail_not_implemented!("table function in filter clause");
82                }
83                Condition::with_expr(expr)
84            }
85            None => Condition::true_cond(),
86        };
87
88        Ok(ExprImpl::AggCall(Box::new(AggCall::new(
89            agg_type,
90            args,
91            distinct,
92            order_by,
93            filter,
94            direct_args,
95        )?)))
96    }
97
98    fn bind_ordered_set_agg(
99        &mut self,
100        kind: &AggType,
101        distinct: bool,
102        args: Vec<ExprImpl>,
103        order_by: Vec<ast::OrderByExpr>,
104        within_group: Option<Box<ast::OrderByExpr>>,
105    ) -> Result<(Vec<Literal>, Vec<ExprImpl>, OrderBy)> {
106        // Syntax:
107        // aggregate_name ( [ expression [ , ... ] ] ) WITHIN GROUP ( order_by_clause ) [ FILTER
108        // ( WHERE filter_clause ) ]
109
110        assert!(matches!(kind, agg_types::ordered_set!()));
111
112        if !order_by.is_empty() {
113            return Err(ErrorCode::InvalidInputSyntax(format!(
114                "`ORDER BY` is not allowed for ordered-set aggregation `{}`",
115                kind
116            ))
117            .into());
118        }
119        if distinct {
120            return Err(ErrorCode::InvalidInputSyntax(format!(
121                "`DISTINCT` is not allowed for ordered-set aggregation `{}`",
122                kind
123            ))
124            .into());
125        }
126
127        let within_group = *within_group.ok_or_else(|| {
128            ErrorCode::InvalidInputSyntax(format!(
129                "`WITHIN GROUP` is expected for ordered-set aggregation `{}`",
130                kind
131            ))
132        })?;
133
134        let mut direct_args = args;
135        let mut args =
136            self.bind_function_expr_arg(FunctionArgExpr::Expr(within_group.expr.clone()))?;
137        let order_by = OrderBy::new(vec![self.bind_order_by_expr(within_group)?]);
138
139        // check signature and do implicit cast
140        match (kind, direct_args.len(), args.as_mut_slice()) {
141            (AggType::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => {
142                let fraction = &mut direct_args[0];
143                decimal_to_float64(fraction, kind)?;
144                if matches!(&kind, AggType::Builtin(PbAggKind::PercentileCont)) {
145                    arg.cast_implicit_mut(DataType::Float64).map_err(|_| {
146                        ErrorCode::InvalidInputSyntax(format!(
147                            "arg in `{}` must be castable to float64",
148                            kind
149                        ))
150                    })?;
151                }
152            }
153            (AggType::Builtin(PbAggKind::Mode), 0, [_arg]) => {}
154            (AggType::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => {
155                let percentile = &mut direct_args[0];
156                decimal_to_float64(percentile, kind)?;
157                match direct_args.len() {
158                    2 => {
159                        let relative_error = &mut direct_args[1];
160                        decimal_to_float64(relative_error, kind)?;
161                        if let Some(relative_error) = relative_error.as_literal()
162                            && let Some(relative_error) = relative_error.get_data()
163                        {
164                            let relative_error = relative_error.as_float64().0;
165                            if relative_error <= 0.0 || relative_error >= 1.0 {
166                                bail!(
167                                    "relative_error={} does not satisfy 0.0 < relative_error < 1.0",
168                                    relative_error,
169                                )
170                            }
171                        }
172                    }
173                    1 => {
174                        let relative_error: ExprImpl = Literal::new(
175                            ScalarImpl::Float64(0.01.into()).into(),
176                            DataType::Float64,
177                        )
178                        .into();
179                        direct_args.push(relative_error);
180                    }
181                    _ => {
182                        return Err(ErrorCode::InvalidInputSyntax(
183                            "invalid direct args for approx_percentile aggregation".to_owned(),
184                        )
185                        .into());
186                    }
187                }
188            }
189            _ => {
190                return Err(ErrorCode::InvalidInputSyntax(format!(
191                    "invalid direct args or within group argument for `{}` aggregation",
192                    kind
193                ))
194                .into());
195            }
196        }
197
198        Ok((
199            direct_args
200                .into_iter()
201                .map(|arg| *arg.into_literal().unwrap())
202                .collect(),
203            args,
204            order_by,
205        ))
206    }
207
208    fn bind_normal_agg(
209        &mut self,
210        kind: &AggType,
211        distinct: bool,
212        args: Vec<ExprImpl>,
213        order_by: Vec<ast::OrderByExpr>,
214        within_group: Option<Box<ast::OrderByExpr>>,
215    ) -> Result<(Vec<Literal>, Vec<ExprImpl>, OrderBy)> {
216        // Syntax:
217        // aggregate_name (expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE
218        //   filter_clause ) ]
219        // aggregate_name (ALL expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE
220        //   filter_clause ) ]
221        // aggregate_name (DISTINCT expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE
222        //   filter_clause ) ]
223        // aggregate_name ( * ) [ FILTER ( WHERE filter_clause ) ]
224
225        assert!(!matches!(kind, agg_types::ordered_set!()));
226
227        if within_group.is_some() {
228            return Err(ErrorCode::InvalidInputSyntax(format!(
229                "`WITHIN GROUP` is not allowed for non-ordered-set aggregation `{}`",
230                kind
231            ))
232            .into());
233        }
234
235        let order_by = OrderBy::new(
236            order_by
237                .into_iter()
238                .map(|e| self.bind_order_by_expr(e))
239                .try_collect()?,
240        );
241
242        if distinct {
243            if matches!(
244                kind,
245                AggType::Builtin(PbAggKind::ApproxCountDistinct)
246                    | AggType::Builtin(PbAggKind::ApproxPercentile)
247            ) {
248                return Err(ErrorCode::InvalidInputSyntax(format!(
249                    "DISTINCT is not allowed for approximate aggregation `{}`",
250                    kind
251                ))
252                .into());
253            }
254
255            if args.is_empty() {
256                return Err(ErrorCode::InvalidInputSyntax(format!(
257                    "DISTINCT is not allowed for aggregate function `{}` without args",
258                    kind
259                ))
260                .into());
261            }
262
263            // restrict arguments[1..] to be constant because we don't support multiple distinct key
264            // indices for now
265            if args.iter().skip(1).any(|arg| arg.as_literal().is_none()) {
266                bail_not_implemented!(
267                    "non-constant arguments other than the first one for DISTINCT aggregation is not supported now"
268                );
269            }
270
271            // restrict ORDER BY to align with PG, which says:
272            // > If DISTINCT is specified in addition to an order_by_clause, then all the ORDER BY
273            // > expressions must match regular arguments of the aggregate; that is, you cannot sort
274            // > on an expression that is not included in the DISTINCT list.
275            if !order_by.sort_exprs.iter().all(|e| args.contains(&e.expr)) {
276                return Err(ErrorCode::InvalidInputSyntax(format!(
277                    "ORDER BY expressions must match regular arguments of the aggregate for `{}` when DISTINCT is provided",
278                    kind
279                ))
280                .into());
281            }
282        }
283
284        Ok((vec![], args, order_by))
285    }
286}
287
288fn decimal_to_float64(decimal_expr: &mut ExprImpl, kind: &AggType) -> Result<()> {
289    if decimal_expr.cast_implicit_mut(DataType::Float64).is_err() {
290        return Err(ErrorCode::InvalidInputSyntax(format!(
291            "direct arg in `{}` must be castable to float64",
292            kind
293        ))
294        .into());
295    }
296
297    let Some(Ok(fraction_datum)) = decimal_expr.try_fold_const() else {
298        bail_not_implemented!(
299            issue = 14079,
300            "variable as direct argument of ordered-set aggregate",
301        );
302    };
303
304    if let Some(ref fraction_value) = fraction_datum
305        && !(0.0..=1.0).contains(&fraction_value.as_float64().0)
306    {
307        return Err(ErrorCode::InvalidInputSyntax(format!(
308            "direct arg in `{}` must between 0.0 and 1.0",
309            kind
310        ))
311        .into());
312    }
313    // note that the fraction can be NULL
314    *decimal_expr = Literal::new(fraction_datum, DataType::Float64).into();
315    Ok(())
316}