// risingwave_expr/aggregate/def.rs

use std::fmt::Display;
use std::iter::Peekable;
use std::str::FromStr;
use std::sync::Arc;

use anyhow::Context;
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
use risingwave_pb::expr::{
    PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
};

use crate::Result;
use crate::expr::{
    BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, build_from_prost,
};
38
39#[derive(Debug, Clone)]
45pub struct AggCall {
46 pub agg_type: AggType,
48
49 pub args: AggArgs,
51
52 pub return_type: DataType,
54
55 pub column_orders: Vec<ColumnOrder>,
57
58 pub filter: Option<Arc<dyn Expression>>,
60
61 pub distinct: bool,
63
64 pub direct_args: Vec<LiteralExpression>,
66}
67
68impl AggCall {
69 pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
70 let agg_type = AggType::from_protobuf_flatten(
71 agg_call.get_kind()?,
72 agg_call.udf.as_ref(),
73 agg_call.scalar.as_ref(),
74 )?;
75 let args = AggArgs::from_protobuf(agg_call.get_args())?;
76 let column_orders = agg_call
77 .get_order_by()
78 .iter()
79 .map(|col_order| {
80 let col_idx = col_order.get_column_index() as usize;
81 let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap());
82 ColumnOrder::new(col_idx, order_type)
83 })
84 .collect();
85 let filter = match agg_call.filter {
86 Some(ref pb_filter) => Some(build_from_prost(pb_filter)?.into()), None => None,
88 };
89 let direct_args = agg_call
90 .direct_args
91 .iter()
92 .map(|arg| {
93 let data_type = DataType::from(arg.get_type().unwrap());
94 LiteralExpression::new(
95 data_type.clone(),
96 Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
97 )
98 })
99 .collect_vec();
100 Ok(AggCall {
101 agg_type,
102 args,
103 return_type: DataType::from(agg_call.get_return_type()?),
104 column_orders,
105 filter,
106 distinct: agg_call.distinct,
107 direct_args,
108 })
109 }
110
111 pub fn from_pretty(s: impl AsRef<str>) -> Self {
119 let tokens = crate::expr::lexer(s.as_ref());
120 Parser::new(tokens.into_iter()).parse_aggregation()
121 }
122
123 pub fn with_filter(mut self, filter: BoxedExpression) -> Self {
124 self.filter = Some(filter.into());
125 self
126 }
127}
128
/// Hand-written parser over lexer tokens, backing [`AggCall::from_pretty`].
struct Parser<Iter: Iterator> {
    /// Remaining input; peekable so optional clauses can be detected without consuming them.
    tokens: Peekable<Iter>,
}
132
133impl<Iter: Iterator<Item = Token>> Parser<Iter> {
134 fn new(tokens: Iter) -> Self {
135 Self {
136 tokens: tokens.peekable(),
137 }
138 }
139
140 fn parse_aggregation(&mut self) -> AggCall {
141 assert_eq!(self.tokens.next(), Some(Token::LParen), "Expected a (");
142 let func = self.parse_function();
143 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
144 let ty = self.parse_type();
145
146 let mut distinct = false;
147 let mut children = Vec::new();
148 let mut column_orders = Vec::new();
149 while matches!(self.tokens.peek(), Some(Token::Index(_))) {
150 children.push(self.parse_arg());
151 }
152 if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "distinct") {
153 distinct = true;
154 self.tokens.next(); }
156 if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "orderby") {
157 self.tokens.next(); while matches!(self.tokens.peek(), Some(Token::Index(_))) {
159 column_orders.push(self.parse_orderkey());
160 }
161 }
162 self.tokens.next(); AggCall {
165 agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
166 args: AggArgs {
167 data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
168 val_indices: children.iter().map(|(idx, _)| *idx).collect(),
169 },
170 return_type: ty,
171 column_orders,
172 filter: None,
173 distinct,
174 direct_args: Vec::new(),
175 }
176 }
177
178 fn parse_type(&mut self) -> DataType {
179 match self.tokens.next().expect("Unexpected end of input") {
180 Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
181 t => panic!("Expected a Literal, got {t:?}"),
182 }
183 }
184
185 fn parse_arg(&mut self) -> (usize, DataType) {
186 let idx = match self.tokens.next().expect("Unexpected end of input") {
187 Token::Index(idx) => idx,
188 t => panic!("Expected an Index, got {t:?}"),
189 };
190 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
191 let ty = self.parse_type();
192 (idx, ty)
193 }
194
195 fn parse_function(&mut self) -> PbAggKind {
196 match self.tokens.next().expect("Unexpected end of input") {
197 Token::Literal(name) => {
198 PbAggKind::from_str_name(&name.to_uppercase()).expect_str("function", &name)
199 }
200 t => panic!("Expected a Literal, got {t:?}"),
201 }
202 }
203
204 fn parse_orderkey(&mut self) -> ColumnOrder {
205 let idx = match self.tokens.next().expect("Unexpected end of input") {
206 Token::Index(idx) => idx,
207 t => panic!("Expected an Index, got {t:?}"),
208 };
209 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
210 let order = match self.tokens.next().expect("Unexpected end of input") {
211 Token::Literal(s) if s == "asc" => OrderType::ascending(),
212 Token::Literal(s) if s == "desc" => OrderType::descending(),
213 t => panic!("Expected asc or desc, got {t:?}"),
214 };
215 ColumnOrder::new(idx, order)
216 }
217}
218
219#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)]
221pub enum AggType {
222 Builtin(PbAggKind),
226
227 UserDefined(PbUserDefinedFunctionMetadata),
229
230 WrapScalar(PbExprNode),
232}
233
234impl Display for AggType {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()),
238 Self::UserDefined(_) => write!(f, "udaf"),
239 Self::WrapScalar(_) => write!(f, "wrap_scalar"),
240 }
241 }
242}
243
244impl FromStr for AggType {
246 type Err = ();
247
248 fn from_str(s: &str) -> Result<Self, Self::Err> {
249 let kind = PbAggKind::from_str(s)?;
250 Ok(AggType::Builtin(kind))
251 }
252}
253
254impl From<PbAggKind> for AggType {
255 fn from(pb: PbAggKind) -> Self {
256 assert!(!matches!(
257 pb,
258 PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar
259 ));
260 AggType::Builtin(pb)
261 }
262}
263
264impl AggType {
265 pub fn from_protobuf_flatten(
266 pb_kind: PbAggKind,
267 user_defined: Option<&PbUserDefinedFunctionMetadata>,
268 scalar: Option<&PbExprNode>,
269 ) -> Result<Self> {
270 match pb_kind {
271 PbAggKind::UserDefined => {
272 let user_defined = user_defined.context("expect user defined")?;
273 Ok(AggType::UserDefined(user_defined.clone()))
274 }
275 PbAggKind::WrapScalar => {
276 let scalar = scalar.context("expect scalar")?;
277 Ok(AggType::WrapScalar(scalar.clone()))
278 }
279 PbAggKind::Unspecified => bail!("Unrecognized agg."),
280 _ => Ok(AggType::Builtin(pb_kind)),
281 }
282 }
283
284 pub fn to_protobuf_simple(&self) -> PbAggKind {
285 match self {
286 Self::Builtin(pb) => *pb,
287 Self::UserDefined(_) => PbAggKind::UserDefined,
288 Self::WrapScalar(_) => PbAggKind::WrapScalar,
289 }
290 }
291
292 pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
293 match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
294 PbAggKind::Unspecified => bail!("Unrecognized agg."),
295 PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
296 PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
297 kind => Ok(AggType::Builtin(kind)),
298 }
299 }
300
301 pub fn to_protobuf(&self) -> PbAggType {
302 match self {
303 Self::Builtin(kind) => PbAggType {
304 kind: *kind as _,
305 udf_meta: None,
306 scalar_expr: None,
307 },
308 Self::UserDefined(udf_meta) => PbAggType {
309 kind: PbAggKind::UserDefined as _,
310 udf_meta: Some(udf_meta.clone()),
311 scalar_expr: None,
312 },
313 Self::WrapScalar(scalar_expr) => PbAggType {
314 kind: PbAggKind::WrapScalar as _,
315 udf_meta: None,
316 scalar_expr: Some(scalar_expr.clone()),
317 },
318 }
319 }
320}
321
/// Reusable patterns over [`AggType`] that group aggregate kinds by a shared property.
///
/// Each macro expands to a pattern, so it can be used directly in `match` arms or in
/// `matches!`. The paths inside the patterns (`AggType`, `PbAggKind`) are resolved at the call
/// site, which therefore needs both names in scope.
pub mod agg_types {
    /// Builtin aggregates not implemented for streaming:
    /// `percentile_cont`, `percentile_disc`, `mode`.
    #[macro_export]
    macro_rules! unimplemented_in_stream {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode,
            )
        };
    }
    pub use unimplemented_in_stream;

    /// Builtin aggregates that are presumably rewritten into other aggregates before execution.
    ///
    /// `AggType::partial_to_total` returns `None` for them.
    #[macro_export]
    macro_rules! rewritten {
        () => {
            AggType::Builtin(
                PbAggKind::Avg
                    | PbAggKind::StddevPop
                    | PbAggKind::StddevSamp
                    | PbAggKind::VarPop
                    | PbAggKind::VarSamp
                    | PbAggKind::Grouping
                    | PbAggKind::ApproxPercentile
                    | PbAggKind::ArgMin
                    | PbAggKind::ArgMax,
            )
        };
    }
    pub use rewritten;

    /// Builtin aggregates whose result does not depend on the order of their input rows.
    #[macro_export]
    macro_rules! result_unaffected_by_order_by {
        () => {
            AggType::Builtin(
                PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BitXor
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::Sum
                    | PbAggKind::Sum0
                    | PbAggKind::Count
                    | PbAggKind::Avg
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::VarPop
                    | PbAggKind::VarSamp
                    | PbAggKind::StddevPop
                    | PbAggKind::StddevSamp,
            )
        };
    }
    pub use result_unaffected_by_order_by;

    /// Builtin aggregates that require ordered input (an `ORDER BY`) to be meaningful.
    #[macro_export]
    macro_rules! must_have_order_by {
        () => {
            AggType::Builtin(
                PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            )
        };
    }
    pub use must_have_order_by;

    /// Builtin aggregates whose result is the same with or without `distinct`.
    #[macro_export]
    macro_rules! result_unaffected_by_distinct {
        () => {
            AggType::Builtin(
                PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::ApproxCountDistinct,
            )
        };
    }
    pub use result_unaffected_by_distinct;

    /// Aggregates that cannot be split into partial and total phases.
    ///
    /// Covers some builtins plus every UDAF and wrapped scalar.
    /// `AggType::partial_to_total` returns `None` for them.
    #[macro_export]
    macro_rules! simply_cannot_two_phase {
        () => {
            AggType::Builtin(
                PbAggKind::StringAgg
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr,
            ) | AggType::UserDefined(_)
                | AggType::WrapScalar(_)
        };
    }
    pub use simply_cannot_two_phase;

    /// Aggregates whose state is presumably a single value; includes every UDAF.
    #[macro_export]
    macro_rules! single_value_state {
        () => {
            AggType::Builtin(
                PbAggKind::Sum
                    | PbAggKind::Sum0
                    | PbAggKind::Count
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BitXor
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::InternalLastSeenValue
                    | PbAggKind::ApproxPercentile,
            ) | AggType::UserDefined(_)
        };
    }
    pub use single_value_state;

    /// Aggregates (`min`, `max`) that presumably keep single-value state only when the input is
    /// append-only.
    #[macro_export]
    macro_rules! single_value_state_iff_in_append_only {
        () => {
            AggType::Builtin(PbAggKind::Max | PbAggKind::Min)
        };
    }
    pub use single_value_state_iff_in_append_only;

    /// Aggregates whose state presumably materializes their input rows.
    ///
    /// Includes every wrapped scalar.
    #[macro_export]
    macro_rules! materialized_input_state {
        () => {
            AggType::Builtin(
                PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::StringAgg
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg,
            ) | AggType::WrapScalar(_)
        };
    }
    pub use materialized_input_state;

    /// Ordered-set aggregates:
    /// `percentile_cont`, `percentile_disc`, `mode`, `approx_percentile`.
    #[macro_export]
    macro_rules! ordered_set {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::ApproxPercentile,
            )
        };
    }
    pub use ordered_set;
}
511
512impl AggType {
513 pub fn partial_to_total(&self) -> Option<Self> {
515 match self {
516 AggType::Builtin(
517 PbAggKind::BitXor
518 | PbAggKind::Min
519 | PbAggKind::Max
520 | PbAggKind::Sum
521 | PbAggKind::InternalLastSeenValue,
522 ) => Some(self.clone()),
523 AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => {
524 Some(Self::Builtin(PbAggKind::Sum0))
525 }
526 agg_types::simply_cannot_two_phase!() => None,
527 agg_types::rewritten!() => None,
528 AggType::Builtin(
530 PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
531 ) => None,
532 }
533 }
534}
535
536#[derive(Clone, Debug, Default)]
538pub struct AggArgs {
539 data_types: Box<[DataType]>,
540 val_indices: Box<[usize]>,
541}
542
543impl AggArgs {
544 pub fn from_protobuf(args: &[PbInputRef]) -> Result<Self> {
545 Ok(AggArgs {
546 data_types: args
547 .iter()
548 .map(|arg| DataType::from(arg.get_type().unwrap()))
549 .collect(),
550 val_indices: args.iter().map(|arg| arg.get_index() as usize).collect(),
551 })
552 }
553
554 pub fn arg_types(&self) -> &[DataType] {
556 &self.data_types
557 }
558
559 pub fn val_indices(&self) -> &[usize] {
561 &self.val_indices
562 }
563}
564
565impl FromIterator<(DataType, usize)> for AggArgs {
566 fn from_iter<T: IntoIterator<Item = (DataType, usize)>>(iter: T) -> Self {
567 let (data_types, val_indices): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
568 AggArgs {
569 data_types: data_types.into(),
570 val_indices: val_indices.into(),
571 }
572 }
573}