risingwave_expr/aggregate/
def.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Aggregation function definitions.
16
17use std::fmt::Display;
18use std::iter::Peekable;
19use std::str::FromStr;
20use std::sync::Arc;
21
22use anyhow::Context;
23use enum_as_inner::EnumAsInner;
24use itertools::Itertools;
25use risingwave_common::bail;
26use risingwave_common::types::{DataType, Datum};
27use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
28use risingwave_common::util::value_encoding::DatumFromProtoExt;
29pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
30use risingwave_pb::expr::{
31    PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
32};
33
34use crate::Result;
35use crate::expr::{
36    BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, build_from_prost,
37};
38
39/// Represents an aggregation function.
40// TODO(runji):
41//  remove this struct from the expression module.
42//  this module only cares about aggregate functions themselves.
43//  advanced features like order by, filter, distinct, etc. should be handled by the upper layer.
44#[derive(Debug, Clone)]
45pub struct AggCall {
46    /// Aggregation type for constructing agg state.
47    pub agg_type: AggType,
48
49    /// Arguments of aggregation function input.
50    pub args: AggArgs,
51
52    /// The return type of aggregation function.
53    pub return_type: DataType,
54
55    /// Order requirements specified in order by clause of agg call
56    pub column_orders: Vec<ColumnOrder>,
57
58    /// Filter of aggregation.
59    pub filter: Option<Arc<dyn Expression>>,
60
61    /// Should deduplicate the input before aggregation.
62    pub distinct: bool,
63
64    /// Constant arguments.
65    pub direct_args: Vec<LiteralExpression>,
66}
67
68impl AggCall {
69    pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
70        let agg_type = AggType::from_protobuf_flatten(
71            agg_call.get_kind()?,
72            agg_call.udf.as_ref(),
73            agg_call.scalar.as_ref(),
74        )?;
75        let args = AggArgs::from_protobuf(agg_call.get_args())?;
76        let column_orders = agg_call
77            .get_order_by()
78            .iter()
79            .map(|col_order| {
80                let col_idx = col_order.get_column_index() as usize;
81                let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap());
82                ColumnOrder::new(col_idx, order_type)
83            })
84            .collect();
85        let filter = match agg_call.filter {
86            Some(ref pb_filter) => Some(build_from_prost(pb_filter)?.into()), /* TODO: non-strict filter in streaming */
87            None => None,
88        };
89        let direct_args = agg_call
90            .direct_args
91            .iter()
92            .map(|arg| {
93                let data_type = DataType::from(arg.get_type().unwrap());
94                LiteralExpression::new(
95                    data_type.clone(),
96                    Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
97                )
98            })
99            .collect_vec();
100        Ok(AggCall {
101            agg_type,
102            args,
103            return_type: DataType::from(agg_call.get_return_type()?),
104            column_orders,
105            filter,
106            distinct: agg_call.distinct,
107            direct_args,
108        })
109    }
110
111    /// Build an `AggCall` from a string.
112    ///
113    /// # Syntax
114    ///
115    /// ```text
116    /// (<name>:<type> [<index>:<type>]* [distinct] [orderby [<index>:<asc|desc>]*])
117    /// ```
118    pub fn from_pretty(s: impl AsRef<str>) -> Self {
119        let tokens = crate::expr::lexer(s.as_ref());
120        Parser::new(tokens.into_iter()).parse_aggregation()
121    }
122
123    pub fn with_filter(mut self, filter: BoxedExpression) -> Self {
124        self.filter = Some(filter.into());
125        self
126    }
127}
128
/// Parser behind [`AggCall::from_pretty`], working on a token stream with one-token lookahead.
struct Parser<Iter>
where
    Iter: Iterator,
{
    /// Remaining tokens; `Peekable` lets the parser inspect the next token without consuming it.
    tokens: Peekable<Iter>,
}
132
133impl<Iter: Iterator<Item = Token>> Parser<Iter> {
134    fn new(tokens: Iter) -> Self {
135        Self {
136            tokens: tokens.peekable(),
137        }
138    }
139
140    fn parse_aggregation(&mut self) -> AggCall {
141        assert_eq!(self.tokens.next(), Some(Token::LParen), "Expected a (");
142        let func = self.parse_function();
143        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
144        let ty = self.parse_type();
145
146        let mut distinct = false;
147        let mut children = Vec::new();
148        let mut column_orders = Vec::new();
149        while matches!(self.tokens.peek(), Some(Token::Index(_))) {
150            children.push(self.parse_arg());
151        }
152        if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "distinct") {
153            distinct = true;
154            self.tokens.next(); // Consume
155        }
156        if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "orderby") {
157            self.tokens.next(); // Consume
158            while matches!(self.tokens.peek(), Some(Token::Index(_))) {
159                column_orders.push(self.parse_orderkey());
160            }
161        }
162        self.tokens.next(); // Consume the RParen
163
164        AggCall {
165            agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
166            args: AggArgs {
167                data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
168                val_indices: children.iter().map(|(idx, _)| *idx).collect(),
169            },
170            return_type: ty,
171            column_orders,
172            filter: None,
173            distinct,
174            direct_args: Vec::new(),
175        }
176    }
177
178    fn parse_type(&mut self) -> DataType {
179        match self.tokens.next().expect("Unexpected end of input") {
180            Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
181            t => panic!("Expected a Literal, got {t:?}"),
182        }
183    }
184
185    fn parse_arg(&mut self) -> (usize, DataType) {
186        let idx = match self.tokens.next().expect("Unexpected end of input") {
187            Token::Index(idx) => idx,
188            t => panic!("Expected an Index, got {t:?}"),
189        };
190        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
191        let ty = self.parse_type();
192        (idx, ty)
193    }
194
195    fn parse_function(&mut self) -> PbAggKind {
196        match self.tokens.next().expect("Unexpected end of input") {
197            Token::Literal(name) => {
198                PbAggKind::from_str_name(&name.to_uppercase()).expect_str("function", &name)
199            }
200            t => panic!("Expected a Literal, got {t:?}"),
201        }
202    }
203
204    fn parse_orderkey(&mut self) -> ColumnOrder {
205        let idx = match self.tokens.next().expect("Unexpected end of input") {
206            Token::Index(idx) => idx,
207            t => panic!("Expected an Index, got {t:?}"),
208        };
209        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
210        let order = match self.tokens.next().expect("Unexpected end of input") {
211            Token::Literal(s) if s == "asc" => OrderType::ascending(),
212            Token::Literal(s) if s == "desc" => OrderType::descending(),
213            t => panic!("Expected asc or desc, got {t:?}"),
214        };
215        ColumnOrder::new(idx, order)
216    }
217}
218
219/// Aggregate function kind.
220#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)]
221pub enum AggType {
222    /// Built-in aggregate function.
223    ///
224    /// The associated value should not be `UserDefined` or `WrapScalar`.
225    Builtin(PbAggKind),
226
227    /// User defined aggregate function.
228    UserDefined(PbUserDefinedFunctionMetadata),
229
230    /// Wrap a scalar function that takes a list as input as an aggregation function.
231    WrapScalar(PbExprNode),
232}
233
234impl Display for AggType {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        match self {
237            Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()),
238            Self::UserDefined(_) => write!(f, "udaf"),
239            Self::WrapScalar(_) => write!(f, "wrap_scalar"),
240        }
241    }
242}
243
244/// `FromStr` for builtin aggregate functions.
245impl FromStr for AggType {
246    type Err = ();
247
248    fn from_str(s: &str) -> Result<Self, Self::Err> {
249        let kind = PbAggKind::from_str(s)?;
250        Ok(AggType::Builtin(kind))
251    }
252}
253
/// Wraps a builtin [`PbAggKind`] into [`AggType::Builtin`].
///
/// # Panics
///
/// Panics if `pb` is `Unspecified`, `UserDefined` or `WrapScalar`, since none of them denotes a
/// builtin aggregate function.
impl From<PbAggKind> for AggType {
    fn from(pb: PbAggKind) -> Self {
        // Reject the kinds that must never be stored inside `Builtin`.
        assert!(!matches!(
            pb,
            PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar
        ));
        AggType::Builtin(pb)
    }
}
263
264impl AggType {
265    pub fn from_protobuf_flatten(
266        pb_kind: PbAggKind,
267        user_defined: Option<&PbUserDefinedFunctionMetadata>,
268        scalar: Option<&PbExprNode>,
269    ) -> Result<Self> {
270        match pb_kind {
271            PbAggKind::UserDefined => {
272                let user_defined = user_defined.context("expect user defined")?;
273                Ok(AggType::UserDefined(user_defined.clone()))
274            }
275            PbAggKind::WrapScalar => {
276                let scalar = scalar.context("expect scalar")?;
277                Ok(AggType::WrapScalar(scalar.clone()))
278            }
279            PbAggKind::Unspecified => bail!("Unrecognized agg."),
280            _ => Ok(AggType::Builtin(pb_kind)),
281        }
282    }
283
284    pub fn to_protobuf_simple(&self) -> PbAggKind {
285        match self {
286            Self::Builtin(pb) => *pb,
287            Self::UserDefined(_) => PbAggKind::UserDefined,
288            Self::WrapScalar(_) => PbAggKind::WrapScalar,
289        }
290    }
291
292    pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
293        match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
294            PbAggKind::Unspecified => bail!("Unrecognized agg."),
295            PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
296            PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
297            kind => Ok(AggType::Builtin(kind)),
298        }
299    }
300
301    pub fn to_protobuf(&self) -> PbAggType {
302        match self {
303            Self::Builtin(kind) => PbAggType {
304                kind: *kind as _,
305                udf_meta: None,
306                scalar_expr: None,
307            },
308            Self::UserDefined(udf_meta) => PbAggType {
309                kind: PbAggKind::UserDefined as _,
310                udf_meta: Some(udf_meta.clone()),
311                scalar_expr: None,
312            },
313            Self::WrapScalar(scalar_expr) => PbAggType {
314                kind: PbAggKind::WrapScalar as _,
315                udf_meta: None,
316                scalar_expr: Some(scalar_expr.clone()),
317            },
318        }
319    }
320}
321
/// Macros to generate match arms for `AggType`.
/// IMPORTANT: These macros must be carefully maintained especially when adding new
/// `AggType`/`PbAggKind` variants.
///
/// Every macro takes no arguments and expands to a pattern over `AggType`. Each one is
/// `#[macro_export]`ed and then re-exported with `pub use`, so it can be invoked through this
/// module's path, e.g. `agg_types::rewritten!()`.
pub mod agg_types {
    /// [`AggType`](super::AggType)s that are currently not supported in streaming mode.
    #[macro_export]
    macro_rules! unimplemented_in_stream {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode,
            )
        };
    }
    pub use unimplemented_in_stream;

    /// [`AggType`](super::AggType)s that should've been rewritten to other kinds. These kinds
    /// should not appear when generating physical plan nodes.
    #[macro_export]
    macro_rules! rewritten {
        () => {
            AggType::Builtin(
                PbAggKind::Avg
                    | PbAggKind::StddevPop
                    | PbAggKind::StddevSamp
                    | PbAggKind::VarPop
                    | PbAggKind::VarSamp
                    | PbAggKind::Grouping
                    // ApproxPercentile always uses custom agg executors,
                    // rather than an aggregation operator
                    | PbAggKind::ApproxPercentile
            )
        };
    }
    pub use rewritten;

    /// [`AggType`](super::AggType)s of which the aggregate results are not affected by the
    /// user given ORDER BY clause.
    #[macro_export]
    macro_rules! result_unaffected_by_order_by {
        () => {
            AggType::Builtin(PbAggKind::BitAnd
                | PbAggKind::BitOr
                | PbAggKind::BitXor // XOR is commutative and associative
                | PbAggKind::BoolAnd
                | PbAggKind::BoolOr
                | PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::Sum0
                | PbAggKind::Count
                | PbAggKind::Avg
                | PbAggKind::ApproxCountDistinct
                | PbAggKind::VarPop
                | PbAggKind::VarSamp
                | PbAggKind::StddevPop
                | PbAggKind::StddevSamp)
        };
    }
    pub use result_unaffected_by_order_by;

    /// [`AggType`](super::AggType)s that must be called with ORDER BY clause. These are
    /// slightly different from variants not in [`result_unaffected_by_order_by`], in that
    /// variants returned by this macro should be banned while the others should just be warned.
    #[macro_export]
    macro_rules! must_have_order_by {
        () => {
            AggType::Builtin(
                PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            )
        };
    }
    pub use must_have_order_by;

    /// [`AggType`](super::AggType)s of which the aggregate results are not affected by the
    /// user given DISTINCT keyword.
    #[macro_export]
    macro_rules! result_unaffected_by_distinct {
        () => {
            AggType::Builtin(
                PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::ApproxCountDistinct,
            )
        };
    }
    pub use result_unaffected_by_distinct;

    /// [`AggType`](super::AggType)s that simply cannot be two-phased. Besides the listed
    /// built-in kinds, this covers every user-defined and wrapped-scalar aggregate.
    #[macro_export]
    macro_rules! simply_cannot_two_phase {
        () => {
            AggType::Builtin(
                PbAggKind::StringAgg
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    // FIXME(wrj): move `BoolAnd` and `BoolOr` out
                    //  after we support general merge in stateless_simple_agg
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
            )
            | AggType::UserDefined(_)
            | AggType::WrapScalar(_)
        };
    }
    pub use simply_cannot_two_phase;

    /// [`AggType`](super::AggType)s that are implemented with a single value state (so-called
    /// stateless).
    #[macro_export]
    macro_rules! single_value_state {
        () => {
            AggType::Builtin(
                PbAggKind::Sum
                    | PbAggKind::Sum0
                    | PbAggKind::Count
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BitXor
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::InternalLastSeenValue
                    | PbAggKind::ApproxPercentile,
            ) | AggType::UserDefined(_)
        };
    }
    pub use single_value_state;

    /// [`AggType`](super::AggType)s that are implemented with a single value state (so-called
    /// stateless) iff the input is append-only.
    #[macro_export]
    macro_rules! single_value_state_iff_in_append_only {
        () => {
            AggType::Builtin(PbAggKind::Max | PbAggKind::Min)
        };
    }
    pub use single_value_state_iff_in_append_only;

    /// [`AggType`](super::AggType)s that are implemented with a materialized input state.
    #[macro_export]
    macro_rules! materialized_input_state {
        () => {
            AggType::Builtin(
                PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::StringAgg
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg,
            ) | AggType::WrapScalar(_)
        };
    }
    pub use materialized_input_state;

    /// Ordered-set aggregate functions.
    #[macro_export]
    macro_rules! ordered_set {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::ApproxPercentile,
            )
        };
    }
    pub use ordered_set;
}
509
510impl AggType {
511    /// Get the total phase agg kind from the partial phase agg kind.
512    pub fn partial_to_total(&self) -> Option<Self> {
513        match self {
514            AggType::Builtin(
515                PbAggKind::BitXor
516                | PbAggKind::Min
517                | PbAggKind::Max
518                | PbAggKind::Sum
519                | PbAggKind::InternalLastSeenValue,
520            ) => Some(self.clone()),
521            AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => {
522                Some(Self::Builtin(PbAggKind::Sum0))
523            }
524            agg_types::simply_cannot_two_phase!() => None,
525            agg_types::rewritten!() => None,
526            // invalid variants
527            AggType::Builtin(
528                PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
529            ) => None,
530        }
531    }
532}
533
534/// An aggregation function may accept 0, 1 or 2 arguments.
535#[derive(Clone, Debug, Default)]
536pub struct AggArgs {
537    data_types: Box<[DataType]>,
538    val_indices: Box<[usize]>,
539}
540
541impl AggArgs {
542    pub fn from_protobuf(args: &[PbInputRef]) -> Result<Self> {
543        Ok(AggArgs {
544            data_types: args
545                .iter()
546                .map(|arg| DataType::from(arg.get_type().unwrap()))
547                .collect(),
548            val_indices: args.iter().map(|arg| arg.get_index() as usize).collect(),
549        })
550    }
551
552    /// return the types of arguments.
553    pub fn arg_types(&self) -> &[DataType] {
554        &self.data_types
555    }
556
557    /// return the indices of the arguments in [`risingwave_common::array::StreamChunk`].
558    pub fn val_indices(&self) -> &[usize] {
559        &self.val_indices
560    }
561}
562
563impl FromIterator<(DataType, usize)> for AggArgs {
564    fn from_iter<T: IntoIterator<Item = (DataType, usize)>>(iter: T) -> Self {
565        let (data_types, val_indices): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
566        AggArgs {
567            data_types: data_types.into(),
568            val_indices: val_indices.into(),
569        }
570    }
571}