risingwave_expr/aggregate/
def.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Aggregation function definitions.
16
17use std::fmt::Display;
18use std::iter::Peekable;
19use std::str::FromStr;
20use std::sync::Arc;
21
22use anyhow::Context;
23use enum_as_inner::EnumAsInner;
24use itertools::Itertools;
25use risingwave_common::bail;
26use risingwave_common::types::{DataType, Datum};
27use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
28use risingwave_common::util::value_encoding::DatumFromProtoExt;
29pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
30use risingwave_pb::expr::{
31    PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
32};
33
34use crate::Result;
35use crate::expr::{
36    BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, build_from_prost,
37};
38
39/// Represents an aggregation function.
40// TODO(runji):
41//  remove this struct from the expression module.
42//  this module only cares about aggregate functions themselves.
43//  advanced features like order by, filter, distinct, etc. should be handled by the upper layer.
44#[derive(Debug, Clone)]
45pub struct AggCall {
46    /// Aggregation type for constructing agg state.
47    pub agg_type: AggType,
48
49    /// Arguments of aggregation function input.
50    pub args: AggArgs,
51
52    /// The return type of aggregation function.
53    pub return_type: DataType,
54
55    /// Order requirements specified in order by clause of agg call
56    pub column_orders: Vec<ColumnOrder>,
57
58    /// Filter of aggregation.
59    pub filter: Option<Arc<dyn Expression>>,
60
61    /// Should deduplicate the input before aggregation.
62    pub distinct: bool,
63
64    /// Constant arguments.
65    pub direct_args: Vec<LiteralExpression>,
66}
67
68impl AggCall {
69    pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
70        let agg_type = AggType::from_protobuf_flatten(
71            agg_call.get_kind()?,
72            agg_call.udf.as_ref(),
73            agg_call.scalar.as_ref(),
74        )?;
75        let args = AggArgs::from_protobuf(agg_call.get_args())?;
76        let column_orders = agg_call
77            .get_order_by()
78            .iter()
79            .map(|col_order| {
80                let col_idx = col_order.get_column_index() as usize;
81                let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap());
82                ColumnOrder::new(col_idx, order_type)
83            })
84            .collect();
85        let filter = match agg_call.filter {
86            Some(ref pb_filter) => Some(build_from_prost(pb_filter)?.into()), /* TODO: non-strict filter in streaming */
87            None => None,
88        };
89        let direct_args = agg_call
90            .direct_args
91            .iter()
92            .map(|arg| {
93                let data_type = DataType::from(arg.get_type().unwrap());
94                LiteralExpression::new(
95                    data_type.clone(),
96                    Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
97                )
98            })
99            .collect_vec();
100        Ok(AggCall {
101            agg_type,
102            args,
103            return_type: DataType::from(agg_call.get_return_type()?),
104            column_orders,
105            filter,
106            distinct: agg_call.distinct,
107            direct_args,
108        })
109    }
110
111    /// Build an `AggCall` from a string.
112    ///
113    /// # Syntax
114    ///
115    /// ```text
116    /// (<name>:<type> [<index>:<type>]* [distinct] [orderby [<index>:<asc|desc>]*])
117    /// ```
118    pub fn from_pretty(s: impl AsRef<str>) -> Self {
119        let tokens = crate::expr::lexer(s.as_ref());
120        Parser::new(tokens.into_iter()).parse_aggregation()
121    }
122
123    pub fn with_filter(mut self, filter: BoxedExpression) -> Self {
124        self.filter = Some(filter.into());
125        self
126    }
127}
128
/// Parser for the textual aggregate-call syntax accepted by [`AggCall::from_pretty`].
struct Parser<I: Iterator> {
    /// Remaining tokens, with one token of lookahead.
    tokens: Peekable<I>,
}
132
133impl<Iter: Iterator<Item = Token>> Parser<Iter> {
134    fn new(tokens: Iter) -> Self {
135        Self {
136            tokens: tokens.peekable(),
137        }
138    }
139
140    fn parse_aggregation(&mut self) -> AggCall {
141        assert_eq!(self.tokens.next(), Some(Token::LParen), "Expected a (");
142        let func = self.parse_function();
143        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
144        let ty = self.parse_type();
145
146        let mut distinct = false;
147        let mut children = Vec::new();
148        let mut column_orders = Vec::new();
149        while matches!(self.tokens.peek(), Some(Token::Index(_))) {
150            children.push(self.parse_arg());
151        }
152        if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "distinct") {
153            distinct = true;
154            self.tokens.next(); // Consume
155        }
156        if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "orderby") {
157            self.tokens.next(); // Consume
158            while matches!(self.tokens.peek(), Some(Token::Index(_))) {
159                column_orders.push(self.parse_orderkey());
160            }
161        }
162        self.tokens.next(); // Consume the RParen
163
164        AggCall {
165            agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
166            args: AggArgs {
167                data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
168                val_indices: children.iter().map(|(idx, _)| *idx).collect(),
169            },
170            return_type: ty,
171            column_orders,
172            filter: None,
173            distinct,
174            direct_args: Vec::new(),
175        }
176    }
177
178    fn parse_type(&mut self) -> DataType {
179        match self.tokens.next().expect("Unexpected end of input") {
180            Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
181            t => panic!("Expected a Literal, got {t:?}"),
182        }
183    }
184
185    fn parse_arg(&mut self) -> (usize, DataType) {
186        let idx = match self.tokens.next().expect("Unexpected end of input") {
187            Token::Index(idx) => idx,
188            t => panic!("Expected an Index, got {t:?}"),
189        };
190        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
191        let ty = self.parse_type();
192        (idx, ty)
193    }
194
195    fn parse_function(&mut self) -> PbAggKind {
196        match self.tokens.next().expect("Unexpected end of input") {
197            Token::Literal(name) => {
198                PbAggKind::from_str_name(&name.to_uppercase()).expect_str("function", &name)
199            }
200            t => panic!("Expected a Literal, got {t:?}"),
201        }
202    }
203
204    fn parse_orderkey(&mut self) -> ColumnOrder {
205        let idx = match self.tokens.next().expect("Unexpected end of input") {
206            Token::Index(idx) => idx,
207            t => panic!("Expected an Index, got {t:?}"),
208        };
209        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
210        let order = match self.tokens.next().expect("Unexpected end of input") {
211            Token::Literal(s) if s == "asc" => OrderType::ascending(),
212            Token::Literal(s) if s == "desc" => OrderType::descending(),
213            t => panic!("Expected asc or desc, got {t:?}"),
214        };
215        ColumnOrder::new(idx, order)
216    }
217}
218
219/// Aggregate function kind.
220#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)]
221pub enum AggType {
222    /// Built-in aggregate function.
223    ///
224    /// The associated value should not be `UserDefined` or `WrapScalar`.
225    Builtin(PbAggKind),
226
227    /// User defined aggregate function.
228    UserDefined(PbUserDefinedFunctionMetadata),
229
230    /// Wrap a scalar function that takes a list as input as an aggregation function.
231    WrapScalar(PbExprNode),
232}
233
234impl Display for AggType {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        match self {
237            Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()),
238            Self::UserDefined(_) => write!(f, "udaf"),
239            Self::WrapScalar(_) => write!(f, "wrap_scalar"),
240        }
241    }
242}
243
244/// `FromStr` for builtin aggregate functions.
245impl FromStr for AggType {
246    type Err = ();
247
248    fn from_str(s: &str) -> Result<Self, Self::Err> {
249        let kind = PbAggKind::from_str(s)?;
250        Ok(AggType::Builtin(kind))
251    }
252}
253
254impl From<PbAggKind> for AggType {
255    fn from(pb: PbAggKind) -> Self {
256        assert!(!matches!(
257            pb,
258            PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar
259        ));
260        AggType::Builtin(pb)
261    }
262}
263
264impl AggType {
265    pub fn from_protobuf_flatten(
266        pb_kind: PbAggKind,
267        user_defined: Option<&PbUserDefinedFunctionMetadata>,
268        scalar: Option<&PbExprNode>,
269    ) -> Result<Self> {
270        match pb_kind {
271            PbAggKind::UserDefined => {
272                let user_defined = user_defined.context("expect user defined")?;
273                Ok(AggType::UserDefined(user_defined.clone()))
274            }
275            PbAggKind::WrapScalar => {
276                let scalar = scalar.context("expect scalar")?;
277                Ok(AggType::WrapScalar(scalar.clone()))
278            }
279            PbAggKind::Unspecified => bail!("Unrecognized agg."),
280            _ => Ok(AggType::Builtin(pb_kind)),
281        }
282    }
283
284    pub fn to_protobuf_simple(&self) -> PbAggKind {
285        match self {
286            Self::Builtin(pb) => *pb,
287            Self::UserDefined(_) => PbAggKind::UserDefined,
288            Self::WrapScalar(_) => PbAggKind::WrapScalar,
289        }
290    }
291
292    pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
293        match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
294            PbAggKind::Unspecified => bail!("Unrecognized agg."),
295            PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
296            PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
297            kind => Ok(AggType::Builtin(kind)),
298        }
299    }
300
301    pub fn to_protobuf(&self) -> PbAggType {
302        match self {
303            Self::Builtin(kind) => PbAggType {
304                kind: *kind as _,
305                udf_meta: None,
306                scalar_expr: None,
307            },
308            Self::UserDefined(udf_meta) => PbAggType {
309                kind: PbAggKind::UserDefined as _,
310                udf_meta: Some(udf_meta.clone()),
311                scalar_expr: None,
312            },
313            Self::WrapScalar(scalar_expr) => PbAggType {
314                kind: PbAggKind::WrapScalar as _,
315                udf_meta: None,
316                scalar_expr: Some(scalar_expr.clone()),
317            },
318        }
319    }
320}
321
/// Macros that expand to match patterns over `AggType`.
///
/// The expansions name `AggType` and `PbAggKind` without a path prefix, so both must be in
/// scope wherever a macro is invoked. Some of them serve as arms of exhaustive matches
/// (e.g. `AggType::partial_to_total`), so the set of variants each one covers matters.
/// IMPORTANT: These macros must be carefully maintained especially when adding new
/// `AggType`/`PbAggKind` variants.
pub mod agg_types {
    /// [`AggType`](super::AggType)s that should've been rewritten to other kinds. These kinds
    /// should not appear when generating physical plan nodes.
    // `#[macro_export]` places the macro at the crate root; the `pub use` below re-exports
    // it under this module so callers can write `agg_types::rewritten!()`.
    #[macro_export]
    macro_rules! rewritten {
        () => {
            AggType::Builtin(
                PbAggKind::Avg
                    | PbAggKind::StddevPop
                    | PbAggKind::StddevSamp
                    | PbAggKind::VarPop
                    | PbAggKind::VarSamp
                    | PbAggKind::Grouping
                    // ApproxPercentile always uses custom agg executors,
                    // rather than an aggregation operator
                    | PbAggKind::ApproxPercentile
                    | PbAggKind::ArgMin
                    | PbAggKind::ArgMax
            )
        };
    }
    pub use rewritten;

    /// [`AggType`](super::AggType)s whose aggregate results are not affected by a
    /// user-given ORDER BY clause.
    #[macro_export]
    macro_rules! result_unaffected_by_order_by {
        () => {
            AggType::Builtin(PbAggKind::BitAnd
                | PbAggKind::BitOr
                | PbAggKind::BitXor // XOR is commutative and associative
                | PbAggKind::BoolAnd
                | PbAggKind::BoolOr
                | PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::Sum0
                | PbAggKind::Count
                | PbAggKind::Avg
                | PbAggKind::ApproxCountDistinct
                | PbAggKind::VarPop
                | PbAggKind::VarSamp
                | PbAggKind::StddevPop
                | PbAggKind::StddevSamp)
        };
    }
    pub use result_unaffected_by_order_by;

    /// [`AggType`](super::AggType)s that must be called with an ORDER BY clause. These differ
    /// slightly from the variants not in [`result_unaffected_by_order_by`]: a call to a
    /// variant matched here without ORDER BY should be rejected, while the others should
    /// only trigger a warning.
    #[macro_export]
    macro_rules! must_have_order_by {
        () => {
            AggType::Builtin(
                PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            )
        };
    }
    pub use must_have_order_by;

    /// [`AggType`](super::AggType)s whose aggregate results are not affected by a
    /// user-given DISTINCT keyword.
    #[macro_export]
    macro_rules! result_unaffected_by_distinct {
        () => {
            AggType::Builtin(
                PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::ApproxCountDistinct,
            )
        };
    }
    pub use result_unaffected_by_distinct;

    /// [`AggType`](super::AggType)s that simply cannot be split into two-phase
    /// aggregation; `AggType::partial_to_total` returns `None` for them.
    #[macro_export]
    macro_rules! simply_cannot_two_phase {
        () => {
            AggType::Builtin(
                PbAggKind::StringAgg
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    // FIXME(wrj): move `BoolAnd` and `BoolOr` out
                    //  after we support general merge in stateless_simple_agg
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
            )
            | AggType::UserDefined(_)
            | AggType::WrapScalar(_)
        };
    }
    pub use simply_cannot_two_phase;

    /// [`AggType`](super::AggType)s that are implemented with a single value state (so-called
    /// stateless).
    #[macro_export]
    macro_rules! single_value_state {
        () => {
            AggType::Builtin(
                PbAggKind::Sum
                    | PbAggKind::Sum0
                    | PbAggKind::Count
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BitXor
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::InternalLastSeenValue
                    | PbAggKind::ApproxPercentile,
            ) | AggType::UserDefined(_)
        };
    }
    pub use single_value_state;

    /// [`AggType`](super::AggType)s that are implemented with a single value state (so-called
    /// stateless) iff the input is append-only.
    #[macro_export]
    macro_rules! single_value_state_iff_in_append_only {
        () => {
            AggType::Builtin(PbAggKind::Max | PbAggKind::Min)
        };
    }
    pub use single_value_state_iff_in_append_only;

    /// [`AggType`](super::AggType)s that are implemented with a materialized input state.
    #[macro_export]
    macro_rules! materialized_input_state {
        () => {
            AggType::Builtin(
                PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::StringAgg
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            ) | AggType::WrapScalar(_)
        };
    }
    pub use materialized_input_state;

    /// Ordered-set aggregate functions.
    #[macro_export]
    macro_rules! ordered_set {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::ApproxPercentile,
            )
        };
    }
    pub use ordered_set;
}
503
504impl AggType {
505    /// Get the total phase agg kind from the partial phase agg kind.
506    pub fn partial_to_total(&self) -> Option<Self> {
507        match self {
508            AggType::Builtin(
509                PbAggKind::BitXor
510                | PbAggKind::Min
511                | PbAggKind::Max
512                | PbAggKind::Sum
513                | PbAggKind::InternalLastSeenValue,
514            ) => Some(self.clone()),
515            AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => {
516                Some(Self::Builtin(PbAggKind::Sum0))
517            }
518            agg_types::simply_cannot_two_phase!() => None,
519            agg_types::rewritten!() => None,
520            // invalid variants
521            AggType::Builtin(
522                PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
523            ) => None,
524        }
525    }
526}
527
528/// An aggregation function may accept 0, 1 or 2 arguments.
529#[derive(Clone, Debug, Default)]
530pub struct AggArgs {
531    data_types: Box<[DataType]>,
532    val_indices: Box<[usize]>,
533}
534
535impl AggArgs {
536    pub fn from_protobuf(args: &[PbInputRef]) -> Result<Self> {
537        Ok(AggArgs {
538            data_types: args
539                .iter()
540                .map(|arg| DataType::from(arg.get_type().unwrap()))
541                .collect(),
542            val_indices: args.iter().map(|arg| arg.get_index() as usize).collect(),
543        })
544    }
545
546    /// return the types of arguments.
547    pub fn arg_types(&self) -> &[DataType] {
548        &self.data_types
549    }
550
551    /// return the indices of the arguments in [`risingwave_common::array::StreamChunk`].
552    pub fn val_indices(&self) -> &[usize] {
553        &self.val_indices
554    }
555}
556
557impl FromIterator<(DataType, usize)> for AggArgs {
558    fn from_iter<T: IntoIterator<Item = (DataType, usize)>>(iter: T) -> Self {
559        let (data_types, val_indices): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
560        AggArgs {
561            data_types: data_types.into(),
562            val_indices: val_indices.into(),
563        }
564    }
565}