risingwave_frontend/binder/expr/function/
aggregate.rs1use itertools::Itertools;
16use risingwave_common::types::{DataType, ScalarImpl};
17use risingwave_common::{bail, bail_not_implemented};
18use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
19use risingwave_sqlparser::ast::{self, FunctionArgExpr};
20
21use crate::Binder;
22use crate::binder::Clause;
23use crate::error::{ErrorCode, Result};
24use crate::expr::{AggCall, ExprImpl, Literal, OrderBy};
25use crate::utils::Condition;
26
27impl Binder {
28 fn ensure_aggregate_allowed(&self) -> Result<()> {
29 if let Some(clause) = self.context.clause {
30 match clause {
31 Clause::Where
32 | Clause::Values
33 | Clause::From
34 | Clause::GeneratedColumn
35 | Clause::Insert
36 | Clause::JoinOn => {
37 return Err(ErrorCode::InvalidInputSyntax(format!(
38 "aggregate functions are not allowed in {}",
39 clause
40 ))
41 .into());
42 }
43 Clause::Having | Clause::Filter | Clause::GroupBy => {}
44 }
45 }
46 Ok(())
47 }
48
49 pub(super) fn bind_aggregate_function(
50 &mut self,
51 agg_type: AggType,
52 distinct: bool,
53 args: Vec<ExprImpl>,
54 order_by: Vec<ast::OrderByExpr>,
55 within_group: Option<Box<ast::OrderByExpr>>,
56 filter: Option<Box<ast::Expr>>,
57 ) -> Result<ExprImpl> {
58 self.ensure_aggregate_allowed()?;
59
60 let (direct_args, args, order_by) = if matches!(agg_type, agg_types::ordered_set!()) {
61 self.bind_ordered_set_agg(&agg_type, distinct, args, order_by, within_group)?
62 } else {
63 self.bind_normal_agg(&agg_type, distinct, args, order_by, within_group)?
64 };
65
66 let filter = match filter {
67 Some(filter) => {
68 let mut clause = Some(Clause::Filter);
69 std::mem::swap(&mut self.context.clause, &mut clause);
70 let expr = self
71 .bind_expr_inner(*filter)
72 .and_then(|expr| expr.enforce_bool_clause("FILTER"))?;
73 self.context.clause = clause;
74 if expr.has_subquery() {
75 bail_not_implemented!("subquery in filter clause");
76 }
77 if expr.has_agg_call() {
78 bail_not_implemented!("aggregation function in filter clause");
79 }
80 if expr.has_table_function() {
81 bail_not_implemented!("table function in filter clause");
82 }
83 Condition::with_expr(expr)
84 }
85 None => Condition::true_cond(),
86 };
87
88 Ok(ExprImpl::AggCall(Box::new(AggCall::new(
89 agg_type,
90 args,
91 distinct,
92 order_by,
93 filter,
94 direct_args,
95 )?)))
96 }
97
98 fn bind_ordered_set_agg(
99 &mut self,
100 kind: &AggType,
101 distinct: bool,
102 args: Vec<ExprImpl>,
103 order_by: Vec<ast::OrderByExpr>,
104 within_group: Option<Box<ast::OrderByExpr>>,
105 ) -> Result<(Vec<Literal>, Vec<ExprImpl>, OrderBy)> {
106 assert!(matches!(kind, agg_types::ordered_set!()));
111
112 if !order_by.is_empty() {
113 return Err(ErrorCode::InvalidInputSyntax(format!(
114 "`ORDER BY` is not allowed for ordered-set aggregation `{}`",
115 kind
116 ))
117 .into());
118 }
119 if distinct {
120 return Err(ErrorCode::InvalidInputSyntax(format!(
121 "`DISTINCT` is not allowed for ordered-set aggregation `{}`",
122 kind
123 ))
124 .into());
125 }
126
127 let within_group = *within_group.ok_or_else(|| {
128 ErrorCode::InvalidInputSyntax(format!(
129 "`WITHIN GROUP` is expected for ordered-set aggregation `{}`",
130 kind
131 ))
132 })?;
133
134 let mut direct_args = args;
135 let mut args =
136 self.bind_function_expr_arg(FunctionArgExpr::Expr(within_group.expr.clone()))?;
137 let order_by = OrderBy::new(vec![self.bind_order_by_expr(within_group)?]);
138
139 match (kind, direct_args.len(), args.as_mut_slice()) {
141 (AggType::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => {
142 let fraction = &mut direct_args[0];
143 decimal_to_float64(fraction, kind)?;
144 if matches!(&kind, AggType::Builtin(PbAggKind::PercentileCont)) {
145 arg.cast_implicit_mut(DataType::Float64).map_err(|_| {
146 ErrorCode::InvalidInputSyntax(format!(
147 "arg in `{}` must be castable to float64",
148 kind
149 ))
150 })?;
151 }
152 }
153 (AggType::Builtin(PbAggKind::Mode), 0, [_arg]) => {}
154 (AggType::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => {
155 let percentile = &mut direct_args[0];
156 decimal_to_float64(percentile, kind)?;
157 match direct_args.len() {
158 2 => {
159 let relative_error = &mut direct_args[1];
160 decimal_to_float64(relative_error, kind)?;
161 if let Some(relative_error) = relative_error.as_literal()
162 && let Some(relative_error) = relative_error.get_data()
163 {
164 let relative_error = relative_error.as_float64().0;
165 if relative_error <= 0.0 || relative_error >= 1.0 {
166 bail!(
167 "relative_error={} does not satisfy 0.0 < relative_error < 1.0",
168 relative_error,
169 )
170 }
171 }
172 }
173 1 => {
174 let relative_error: ExprImpl = Literal::new(
175 ScalarImpl::Float64(0.01.into()).into(),
176 DataType::Float64,
177 )
178 .into();
179 direct_args.push(relative_error);
180 }
181 _ => {
182 return Err(ErrorCode::InvalidInputSyntax(
183 "invalid direct args for approx_percentile aggregation".to_owned(),
184 )
185 .into());
186 }
187 }
188 }
189 _ => {
190 return Err(ErrorCode::InvalidInputSyntax(format!(
191 "invalid direct args or within group argument for `{}` aggregation",
192 kind
193 ))
194 .into());
195 }
196 }
197
198 Ok((
199 direct_args
200 .into_iter()
201 .map(|arg| *arg.into_literal().unwrap())
202 .collect(),
203 args,
204 order_by,
205 ))
206 }
207
208 fn bind_normal_agg(
209 &mut self,
210 kind: &AggType,
211 distinct: bool,
212 args: Vec<ExprImpl>,
213 order_by: Vec<ast::OrderByExpr>,
214 within_group: Option<Box<ast::OrderByExpr>>,
215 ) -> Result<(Vec<Literal>, Vec<ExprImpl>, OrderBy)> {
216 assert!(!matches!(kind, agg_types::ordered_set!()));
226
227 if within_group.is_some() {
228 return Err(ErrorCode::InvalidInputSyntax(format!(
229 "`WITHIN GROUP` is not allowed for non-ordered-set aggregation `{}`",
230 kind
231 ))
232 .into());
233 }
234
235 let order_by = OrderBy::new(
236 order_by
237 .into_iter()
238 .map(|e| self.bind_order_by_expr(e))
239 .try_collect()?,
240 );
241
242 if distinct {
243 if matches!(
244 kind,
245 AggType::Builtin(PbAggKind::ApproxCountDistinct)
246 | AggType::Builtin(PbAggKind::ApproxPercentile)
247 ) {
248 return Err(ErrorCode::InvalidInputSyntax(format!(
249 "DISTINCT is not allowed for approximate aggregation `{}`",
250 kind
251 ))
252 .into());
253 }
254
255 if args.is_empty() {
256 return Err(ErrorCode::InvalidInputSyntax(format!(
257 "DISTINCT is not allowed for aggregate function `{}` without args",
258 kind
259 ))
260 .into());
261 }
262
263 if args.iter().skip(1).any(|arg| arg.as_literal().is_none()) {
266 bail_not_implemented!(
267 "non-constant arguments other than the first one for DISTINCT aggregation is not supported now"
268 );
269 }
270
271 if !order_by.sort_exprs.iter().all(|e| args.contains(&e.expr)) {
276 return Err(ErrorCode::InvalidInputSyntax(format!(
277 "ORDER BY expressions must match regular arguments of the aggregate for `{}` when DISTINCT is provided",
278 kind
279 ))
280 .into());
281 }
282 }
283
284 Ok((vec![], args, order_by))
285 }
286}
287
288fn decimal_to_float64(decimal_expr: &mut ExprImpl, kind: &AggType) -> Result<()> {
289 if decimal_expr.cast_implicit_mut(DataType::Float64).is_err() {
290 return Err(ErrorCode::InvalidInputSyntax(format!(
291 "direct arg in `{}` must be castable to float64",
292 kind
293 ))
294 .into());
295 }
296
297 let Some(Ok(fraction_datum)) = decimal_expr.try_fold_const() else {
298 bail_not_implemented!(
299 issue = 14079,
300 "variable as direct argument of ordered-set aggregate",
301 );
302 };
303
304 if let Some(ref fraction_value) = fraction_datum
305 && !(0.0..=1.0).contains(&fraction_value.as_float64().0)
306 {
307 return Err(ErrorCode::InvalidInputSyntax(format!(
308 "direct arg in `{}` must between 0.0 and 1.0",
309 kind
310 ))
311 .into());
312 }
313 *decimal_expr = Literal::new(fraction_datum, DataType::Float64).into();
315 Ok(())
316}