1use std::fmt::Display;
18use std::iter::Peekable;
19use std::str::FromStr;
20use std::sync::Arc;
21
22use anyhow::Context;
23use enum_as_inner::EnumAsInner;
24use itertools::Itertools;
25use risingwave_common::bail;
26use risingwave_common::types::{DataType, Datum};
27use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
28use risingwave_common::util::value_encoding::DatumFromProtoExt;
29pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
30use risingwave_pb::expr::{
31 PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
32};
33
34use crate::Result;
35use crate::expr::{
36 BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, build_from_prost,
37};
38
39#[derive(Debug, Clone)]
45pub struct AggCall {
46 pub agg_type: AggType,
48
49 pub args: AggArgs,
51
52 pub return_type: DataType,
54
55 pub column_orders: Vec<ColumnOrder>,
57
58 pub filter: Option<Arc<dyn Expression>>,
60
61 pub distinct: bool,
63
64 pub direct_args: Vec<LiteralExpression>,
66}
67
68impl AggCall {
69 pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
70 let agg_type = AggType::from_protobuf_flatten(
71 agg_call.get_kind()?,
72 agg_call.udf.as_ref(),
73 agg_call.scalar.as_ref(),
74 )?;
75 let args = AggArgs::from_protobuf(agg_call.get_args())?;
76 let column_orders = agg_call
77 .get_order_by()
78 .iter()
79 .map(|col_order| {
80 let col_idx = col_order.get_column_index() as usize;
81 let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap());
82 ColumnOrder::new(col_idx, order_type)
83 })
84 .collect();
85 let filter = match agg_call.filter {
86 Some(ref pb_filter) => Some(build_from_prost(pb_filter)?.into()), None => None,
88 };
89 let direct_args = agg_call
90 .direct_args
91 .iter()
92 .map(|arg| {
93 let data_type = DataType::from(arg.get_type().unwrap());
94 LiteralExpression::new(
95 data_type.clone(),
96 Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
97 )
98 })
99 .collect_vec();
100 Ok(AggCall {
101 agg_type,
102 args,
103 return_type: DataType::from(agg_call.get_return_type()?),
104 column_orders,
105 filter,
106 distinct: agg_call.distinct,
107 direct_args,
108 })
109 }
110
111 pub fn from_pretty(s: impl AsRef<str>) -> Self {
119 let tokens = crate::expr::lexer(s.as_ref());
120 Parser::new(tokens.into_iter()).parse_aggregation()
121 }
122
123 pub fn with_filter(mut self, filter: BoxedExpression) -> Self {
124 self.filter = Some(filter.into());
125 self
126 }
127}
128
/// Hand-written parser over lexer tokens, used by [`AggCall::from_pretty`].
struct Parser<Iter>
where
    Iter: Iterator,
{
    /// Remaining input; peekable so optional clauses can be detected before consuming them.
    tokens: Peekable<Iter>,
}
132
133impl<Iter: Iterator<Item = Token>> Parser<Iter> {
134 fn new(tokens: Iter) -> Self {
135 Self {
136 tokens: tokens.peekable(),
137 }
138 }
139
140 fn parse_aggregation(&mut self) -> AggCall {
141 assert_eq!(self.tokens.next(), Some(Token::LParen), "Expected a (");
142 let func = self.parse_function();
143 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
144 let ty = self.parse_type();
145
146 let mut distinct = false;
147 let mut children = Vec::new();
148 let mut column_orders = Vec::new();
149 while matches!(self.tokens.peek(), Some(Token::Index(_))) {
150 children.push(self.parse_arg());
151 }
152 if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "distinct") {
153 distinct = true;
154 self.tokens.next(); }
156 if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "orderby") {
157 self.tokens.next(); while matches!(self.tokens.peek(), Some(Token::Index(_))) {
159 column_orders.push(self.parse_orderkey());
160 }
161 }
162 self.tokens.next(); AggCall {
165 agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
166 args: AggArgs {
167 data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
168 val_indices: children.iter().map(|(idx, _)| *idx).collect(),
169 },
170 return_type: ty,
171 column_orders,
172 filter: None,
173 distinct,
174 direct_args: Vec::new(),
175 }
176 }
177
178 fn parse_type(&mut self) -> DataType {
179 match self.tokens.next().expect("Unexpected end of input") {
180 Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
181 t => panic!("Expected a Literal, got {t:?}"),
182 }
183 }
184
185 fn parse_arg(&mut self) -> (usize, DataType) {
186 let idx = match self.tokens.next().expect("Unexpected end of input") {
187 Token::Index(idx) => idx,
188 t => panic!("Expected an Index, got {t:?}"),
189 };
190 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
191 let ty = self.parse_type();
192 (idx, ty)
193 }
194
195 fn parse_function(&mut self) -> PbAggKind {
196 match self.tokens.next().expect("Unexpected end of input") {
197 Token::Literal(name) => {
198 PbAggKind::from_str_name(&name.to_uppercase()).expect_str("function", &name)
199 }
200 t => panic!("Expected a Literal, got {t:?}"),
201 }
202 }
203
204 fn parse_orderkey(&mut self) -> ColumnOrder {
205 let idx = match self.tokens.next().expect("Unexpected end of input") {
206 Token::Index(idx) => idx,
207 t => panic!("Expected an Index, got {t:?}"),
208 };
209 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
210 let order = match self.tokens.next().expect("Unexpected end of input") {
211 Token::Literal(s) if s == "asc" => OrderType::ascending(),
212 Token::Literal(s) if s == "desc" => OrderType::descending(),
213 t => panic!("Expected asc or desc, got {t:?}"),
214 };
215 ColumnOrder::new(idx, order)
216 }
217}
218
219#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)]
221pub enum AggType {
222 Builtin(PbAggKind),
226
227 UserDefined(PbUserDefinedFunctionMetadata),
229
230 WrapScalar(PbExprNode),
232}
233
234impl Display for AggType {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()),
238 Self::UserDefined(_) => write!(f, "udaf"),
239 Self::WrapScalar(_) => write!(f, "wrap_scalar"),
240 }
241 }
242}
243
244impl FromStr for AggType {
246 type Err = ();
247
248 fn from_str(s: &str) -> Result<Self, Self::Err> {
249 let kind = PbAggKind::from_str(s)?;
250 Ok(AggType::Builtin(kind))
251 }
252}
253
254impl From<PbAggKind> for AggType {
255 fn from(pb: PbAggKind) -> Self {
256 assert!(!matches!(
257 pb,
258 PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar
259 ));
260 AggType::Builtin(pb)
261 }
262}
263
264impl AggType {
265 pub fn from_protobuf_flatten(
266 pb_kind: PbAggKind,
267 user_defined: Option<&PbUserDefinedFunctionMetadata>,
268 scalar: Option<&PbExprNode>,
269 ) -> Result<Self> {
270 match pb_kind {
271 PbAggKind::UserDefined => {
272 let user_defined = user_defined.context("expect user defined")?;
273 Ok(AggType::UserDefined(user_defined.clone()))
274 }
275 PbAggKind::WrapScalar => {
276 let scalar = scalar.context("expect scalar")?;
277 Ok(AggType::WrapScalar(scalar.clone()))
278 }
279 PbAggKind::Unspecified => bail!("Unrecognized agg."),
280 _ => Ok(AggType::Builtin(pb_kind)),
281 }
282 }
283
284 pub fn to_protobuf_simple(&self) -> PbAggKind {
285 match self {
286 Self::Builtin(pb) => *pb,
287 Self::UserDefined(_) => PbAggKind::UserDefined,
288 Self::WrapScalar(_) => PbAggKind::WrapScalar,
289 }
290 }
291
292 pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
293 match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
294 PbAggKind::Unspecified => bail!("Unrecognized agg."),
295 PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
296 PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
297 kind => Ok(AggType::Builtin(kind)),
298 }
299 }
300
301 pub fn to_protobuf(&self) -> PbAggType {
302 match self {
303 Self::Builtin(kind) => PbAggType {
304 kind: *kind as _,
305 udf_meta: None,
306 scalar_expr: None,
307 },
308 Self::UserDefined(udf_meta) => PbAggType {
309 kind: PbAggKind::UserDefined as _,
310 udf_meta: Some(udf_meta.clone()),
311 scalar_expr: None,
312 },
313 Self::WrapScalar(scalar_expr) => PbAggType {
314 kind: PbAggKind::WrapScalar as _,
315 udf_meta: None,
316 scalar_expr: Some(scalar_expr.clone()),
317 },
318 }
319 }
320}
321
/// Pattern macros that classify [`AggType`] values.
///
/// Each macro expands to a *pattern* for use in `match` arms or `matches!`. The expansions
/// name `AggType` and `PbAggKind` unqualified, so both must be in scope at the call site.
/// `#[macro_export]` places every macro at the crate root. The `pub use` lines re-export
/// them under this module path.
pub mod agg_types {
    /// Matches the builtin kinds that, per this macro's name, are not implemented in
    /// streaming: `PercentileCont`, `PercentileDisc`, `Mode`.
    #[macro_export]
    macro_rules! unimplemented_in_stream {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode,
            )
        };
    }
    pub use unimplemented_in_stream;

    /// Matches builtin kinds classified as "rewritten".
    ///
    /// NOTE(review): presumably these are replaced by other aggregates before execution —
    /// confirm against the planner. `AggType::partial_to_total` returns `None` for them.
    #[macro_export]
    macro_rules! rewritten {
        () => {
            AggType::Builtin(
                PbAggKind::Avg
                    | PbAggKind::StddevPop
                    | PbAggKind::StddevSamp
                    | PbAggKind::VarPop
                    | PbAggKind::VarSamp
                    | PbAggKind::Grouping
                    | PbAggKind::ApproxPercentile
            )
        };
    }
    pub use rewritten;

    /// Matches builtin kinds whose result, per this macro's name, does not depend on an
    /// `ORDER BY` inside the call.
    #[macro_export]
    macro_rules! result_unaffected_by_order_by {
        () => {
            AggType::Builtin(PbAggKind::BitAnd
                | PbAggKind::BitOr
                | PbAggKind::BitXor | PbAggKind::BoolAnd
                | PbAggKind::BoolOr
                | PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::Sum0
                | PbAggKind::Count
                | PbAggKind::Avg
                | PbAggKind::ApproxCountDistinct
                | PbAggKind::VarPop
                | PbAggKind::VarSamp
                | PbAggKind::StddevPop
                | PbAggKind::StddevSamp)
        };
    }
    pub use result_unaffected_by_order_by;

    /// Matches builtin kinds that, per this macro's name, require an ordering to be
    /// specified: `FirstValue`, `LastValue`, `PercentileCont`, `PercentileDisc`, `Mode`.
    #[macro_export]
    macro_rules! must_have_order_by {
        () => {
            AggType::Builtin(
                PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            )
        };
    }
    pub use must_have_order_by;

    /// Matches builtin kinds whose result, per this macro's name, is the same with or
    /// without `DISTINCT`.
    #[macro_export]
    macro_rules! result_unaffected_by_distinct {
        () => {
            AggType::Builtin(
                PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::ApproxCountDistinct,
            )
        };
    }
    pub use result_unaffected_by_distinct;

    /// Matches aggregates that, per this macro's name, cannot be split into two phases.
    ///
    /// Besides the listed builtin kinds, this also matches every `UserDefined` and
    /// `WrapScalar` aggregate. `AggType::partial_to_total` returns `None` for all of them.
    #[macro_export]
    macro_rules! simply_cannot_two_phase {
        () => {
            AggType::Builtin(
                PbAggKind::StringAgg
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
            )
            | AggType::UserDefined(_)
            | AggType::WrapScalar(_)
        };
    }
    pub use simply_cannot_two_phase;

    /// Matches aggregates whose state, per this macro's name, is a single value.
    /// Also matches every `UserDefined` aggregate.
    #[macro_export]
    macro_rules! single_value_state {
        () => {
            AggType::Builtin(
                PbAggKind::Sum
                    | PbAggKind::Sum0
                    | PbAggKind::Count
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BitXor
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::InternalLastSeenValue
                    | PbAggKind::ApproxPercentile,
            ) | AggType::UserDefined(_)
        };
    }
    pub use single_value_state;

    /// Matches `Max` and `Min`, which per this macro's name keep a single-value state only
    /// when the input is append-only.
    #[macro_export]
    macro_rules! single_value_state_iff_in_append_only {
        () => {
            AggType::Builtin(PbAggKind::Max | PbAggKind::Min)
        };
    }
    pub use single_value_state_iff_in_append_only;

    /// Matches aggregates whose state, per this macro's name, is the materialized input.
    /// Also matches every `WrapScalar` aggregate.
    #[macro_export]
    macro_rules! materialized_input_state {
        () => {
            AggType::Builtin(
                PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::StringAgg
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg,
            ) | AggType::WrapScalar(_)
        };
    }
    pub use materialized_input_state;

    /// Matches the ordered-set builtin kinds: `PercentileCont`, `PercentileDisc`, `Mode`,
    /// `ApproxPercentile`.
    #[macro_export]
    macro_rules! ordered_set {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::ApproxPercentile,
            )
        };
    }
    pub use ordered_set;
}
509
510impl AggType {
511 pub fn partial_to_total(&self) -> Option<Self> {
513 match self {
514 AggType::Builtin(
515 PbAggKind::BitXor
516 | PbAggKind::Min
517 | PbAggKind::Max
518 | PbAggKind::Sum
519 | PbAggKind::InternalLastSeenValue,
520 ) => Some(self.clone()),
521 AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => {
522 Some(Self::Builtin(PbAggKind::Sum0))
523 }
524 agg_types::simply_cannot_two_phase!() => None,
525 agg_types::rewritten!() => None,
526 AggType::Builtin(
528 PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
529 ) => None,
530 }
531 }
532}
533
534#[derive(Clone, Debug, Default)]
536pub struct AggArgs {
537 data_types: Box<[DataType]>,
538 val_indices: Box<[usize]>,
539}
540
541impl AggArgs {
542 pub fn from_protobuf(args: &[PbInputRef]) -> Result<Self> {
543 Ok(AggArgs {
544 data_types: args
545 .iter()
546 .map(|arg| DataType::from(arg.get_type().unwrap()))
547 .collect(),
548 val_indices: args.iter().map(|arg| arg.get_index() as usize).collect(),
549 })
550 }
551
552 pub fn arg_types(&self) -> &[DataType] {
554 &self.data_types
555 }
556
557 pub fn val_indices(&self) -> &[usize] {
559 &self.val_indices
560 }
561}
562
563impl FromIterator<(DataType, usize)> for AggArgs {
564 fn from_iter<T: IntoIterator<Item = (DataType, usize)>>(iter: T) -> Self {
565 let (data_types, val_indices): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
566 AggArgs {
567 data_types: data_types.into(),
568 val_indices: val_indices.into(),
569 }
570 }
571}