risingwave_expr/aggregate/
def.rs1use std::fmt::Display;
18use std::iter::Peekable;
19use std::str::FromStr;
20use std::sync::Arc;
21
22use anyhow::Context;
23use enum_as_inner::EnumAsInner;
24use itertools::Itertools;
25use risingwave_common::bail;
26use risingwave_common::types::{DataType, Datum};
27use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
28use risingwave_common::util::value_encoding::DatumFromProtoExt;
29pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
30use risingwave_pb::expr::{
31 PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
32};
33
34use crate::Result;
35use crate::expr::{
36 BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, build_from_prost,
37};
38
/// Description of a single aggregate function call.
///
/// Instances are built either from a protobuf `PbAggCall` (see
/// `AggCall::from_protobuf`) or from a compact textual form (see
/// `AggCall::from_pretty`).
#[derive(Debug, Clone)]
pub struct AggCall {
    /// Which aggregate function is called (builtin, user-defined or wrapped scalar).
    pub agg_type: AggType,

    /// Types and input column indices of the arguments.
    pub args: AggArgs,

    /// Data type of the aggregate's result.
    pub return_type: DataType,

    /// Ordering keys (`column index` + `order type`) attached to the call.
    pub column_orders: Vec<ColumnOrder>,

    /// Optional filter expression attached to the call; `None` when there is no filter.
    pub filter: Option<Arc<dyn Expression>>,

    /// The `distinct` flag of the call.
    pub distinct: bool,

    /// Literal (constant) arguments of the call.
    pub direct_args: Vec<LiteralExpression>,
}
67
68impl AggCall {
69 pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
70 let agg_type = AggType::from_protobuf_flatten(
71 agg_call.get_kind()?,
72 agg_call.udf.as_ref(),
73 agg_call.scalar.as_ref(),
74 )?;
75 let args = AggArgs::from_protobuf(agg_call.get_args())?;
76 let column_orders = agg_call
77 .get_order_by()
78 .iter()
79 .map(|col_order| {
80 let col_idx = col_order.get_column_index() as usize;
81 let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap());
82 ColumnOrder::new(col_idx, order_type)
83 })
84 .collect();
85 let filter = match agg_call.filter {
86 Some(ref pb_filter) => Some(build_from_prost(pb_filter)?.into()), None => None,
88 };
89 let direct_args = agg_call
90 .direct_args
91 .iter()
92 .map(|arg| {
93 let data_type = DataType::from(arg.get_type().unwrap());
94 LiteralExpression::new(
95 data_type.clone(),
96 Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
97 )
98 })
99 .collect_vec();
100 Ok(AggCall {
101 agg_type,
102 args,
103 return_type: DataType::from(agg_call.get_return_type()?),
104 column_orders,
105 filter,
106 distinct: agg_call.distinct,
107 direct_args,
108 })
109 }
110
111 pub fn from_pretty(s: impl AsRef<str>) -> Self {
119 let tokens = crate::expr::lexer(s.as_ref());
120 Parser::new(tokens.into_iter()).parse_aggregation()
121 }
122
123 pub fn with_filter(mut self, filter: BoxedExpression) -> Self {
124 self.filter = Some(filter.into());
125 self
126 }
127}
128
/// Parser over a token stream, used by `AggCall::from_pretty` to build an `AggCall`.
struct Parser<Iter: Iterator> {
    /// Remaining tokens; peekable so the parser can look one token ahead.
    tokens: Peekable<Iter>,
}
132
133impl<Iter: Iterator<Item = Token>> Parser<Iter> {
134 fn new(tokens: Iter) -> Self {
135 Self {
136 tokens: tokens.peekable(),
137 }
138 }
139
140 fn parse_aggregation(&mut self) -> AggCall {
141 assert_eq!(self.tokens.next(), Some(Token::LParen), "Expected a (");
142 let func = self.parse_function();
143 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
144 let ty = self.parse_type();
145
146 let mut distinct = false;
147 let mut children = Vec::new();
148 let mut column_orders = Vec::new();
149 while matches!(self.tokens.peek(), Some(Token::Index(_))) {
150 children.push(self.parse_arg());
151 }
152 if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "distinct") {
153 distinct = true;
154 self.tokens.next(); }
156 if matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == "orderby") {
157 self.tokens.next(); while matches!(self.tokens.peek(), Some(Token::Index(_))) {
159 column_orders.push(self.parse_orderkey());
160 }
161 }
162 self.tokens.next(); AggCall {
165 agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
166 args: AggArgs {
167 data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
168 val_indices: children.iter().map(|(idx, _)| *idx).collect(),
169 },
170 return_type: ty,
171 column_orders,
172 filter: None,
173 distinct,
174 direct_args: Vec::new(),
175 }
176 }
177
178 fn parse_type(&mut self) -> DataType {
179 match self.tokens.next().expect("Unexpected end of input") {
180 Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
181 t => panic!("Expected a Literal, got {t:?}"),
182 }
183 }
184
185 fn parse_arg(&mut self) -> (usize, DataType) {
186 let idx = match self.tokens.next().expect("Unexpected end of input") {
187 Token::Index(idx) => idx,
188 t => panic!("Expected an Index, got {t:?}"),
189 };
190 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
191 let ty = self.parse_type();
192 (idx, ty)
193 }
194
195 fn parse_function(&mut self) -> PbAggKind {
196 match self.tokens.next().expect("Unexpected end of input") {
197 Token::Literal(name) => {
198 PbAggKind::from_str_name(&name.to_uppercase()).expect_str("function", &name)
199 }
200 t => panic!("Expected a Literal, got {t:?}"),
201 }
202 }
203
204 fn parse_orderkey(&mut self) -> ColumnOrder {
205 let idx = match self.tokens.next().expect("Unexpected end of input") {
206 Token::Index(idx) => idx,
207 t => panic!("Expected an Index, got {t:?}"),
208 };
209 assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
210 let order = match self.tokens.next().expect("Unexpected end of input") {
211 Token::Literal(s) if s == "asc" => OrderType::ascending(),
212 Token::Literal(s) if s == "desc" => OrderType::descending(),
213 t => panic!("Expected asc or desc, got {t:?}"),
214 };
215 ColumnOrder::new(idx, order)
216 }
217}
218
/// The kind of an aggregate function, together with the extra payload that
/// user-defined and wrapped-scalar aggregates carry.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)]
pub enum AggType {
    /// A builtin aggregate identified by its protobuf kind.
    ///
    /// The constructors in this file never store `Unspecified`, `UserDefined`
    /// or `WrapScalar` in this variant.
    Builtin(PbAggKind),

    /// A user-defined aggregate, carrying its function metadata.
    UserDefined(PbUserDefinedFunctionMetadata),

    /// An aggregate defined by a scalar expression, carrying that expression node.
    WrapScalar(PbExprNode),
}
233
234impl Display for AggType {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()),
238 Self::UserDefined(_) => write!(f, "udaf"),
239 Self::WrapScalar(_) => write!(f, "wrap_scalar"),
240 }
241 }
242}
243
244impl FromStr for AggType {
246 type Err = ();
247
248 fn from_str(s: &str) -> Result<Self, Self::Err> {
249 let kind = PbAggKind::from_str(s)?;
250 Ok(AggType::Builtin(kind))
251 }
252}
253
/// Wraps a protobuf kind as a builtin aggregate.
///
/// # Panics
///
/// Panics when `pb` is `Unspecified`, `UserDefined` or `WrapScalar`: those
/// kinds are not builtin aggregates and cannot be represented by
/// `AggType::Builtin`.
impl From<PbAggKind> for AggType {
    fn from(pb: PbAggKind) -> Self {
        assert!(!matches!(
            pb,
            PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar
        ));
        AggType::Builtin(pb)
    }
}
263
264impl AggType {
265 pub fn from_protobuf_flatten(
266 pb_kind: PbAggKind,
267 user_defined: Option<&PbUserDefinedFunctionMetadata>,
268 scalar: Option<&PbExprNode>,
269 ) -> Result<Self> {
270 match pb_kind {
271 PbAggKind::UserDefined => {
272 let user_defined = user_defined.context("expect user defined")?;
273 Ok(AggType::UserDefined(user_defined.clone()))
274 }
275 PbAggKind::WrapScalar => {
276 let scalar = scalar.context("expect scalar")?;
277 Ok(AggType::WrapScalar(scalar.clone()))
278 }
279 PbAggKind::Unspecified => bail!("Unrecognized agg."),
280 _ => Ok(AggType::Builtin(pb_kind)),
281 }
282 }
283
284 pub fn to_protobuf_simple(&self) -> PbAggKind {
285 match self {
286 Self::Builtin(pb) => *pb,
287 Self::UserDefined(_) => PbAggKind::UserDefined,
288 Self::WrapScalar(_) => PbAggKind::WrapScalar,
289 }
290 }
291
292 pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
293 match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
294 PbAggKind::Unspecified => bail!("Unrecognized agg."),
295 PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
296 PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
297 kind => Ok(AggType::Builtin(kind)),
298 }
299 }
300
301 pub fn to_protobuf(&self) -> PbAggType {
302 match self {
303 Self::Builtin(kind) => PbAggType {
304 kind: *kind as _,
305 udf_meta: None,
306 scalar_expr: None,
307 },
308 Self::UserDefined(udf_meta) => PbAggType {
309 kind: PbAggKind::UserDefined as _,
310 udf_meta: Some(udf_meta.clone()),
311 scalar_expr: None,
312 },
313 Self::WrapScalar(scalar_expr) => PbAggType {
314 kind: PbAggKind::WrapScalar as _,
315 udf_meta: None,
316 scalar_expr: Some(scalar_expr.clone()),
317 },
318 }
319 }
320}
321
/// Pattern macros that group [`AggType`]s by a shared property.
///
/// Each macro expands to a *pattern*, so it can be used in `match` arms and in
/// `matches!`. The expansions name `AggType` and `PbAggKind` unqualified, so
/// both must be in scope at the call site. Every macro is `#[macro_export]`ed
/// and re-exported here, which makes it reachable as `agg_types::<name>!()`.
pub mod agg_types {
    /// Matches the builtin kinds `Avg`, `StddevPop`, `StddevSamp`, `VarPop`,
    /// `VarSamp`, `Grouping`, `ApproxPercentile`, `ArgMin` and `ArgMax`.
    ///
    /// NOTE(review): judging by the name, these are aggregates that get
    /// rewritten into other expressions elsewhere; that rewrite is not visible
    /// in this file, so confirm against the callers.
    #[macro_export]
    macro_rules! rewritten {
        () => {
            AggType::Builtin(
                PbAggKind::Avg
                    | PbAggKind::StddevPop
                    | PbAggKind::StddevSamp
                    | PbAggKind::VarPop
                    | PbAggKind::VarSamp
                    | PbAggKind::Grouping
                    | PbAggKind::ApproxPercentile
                    | PbAggKind::ArgMin
                    | PbAggKind::ArgMax
            )
        };
    }
    pub use rewritten;

    /// Matches the builtin kinds listed below.
    ///
    /// NOTE(review): the name indicates that the result of these aggregates
    /// does not depend on the order of their input — confirm with the callers
    /// before extending the list.
    #[macro_export]
    macro_rules! result_unaffected_by_order_by {
        () => {
            AggType::Builtin(PbAggKind::BitAnd
                | PbAggKind::BitOr
                | PbAggKind::BitXor | PbAggKind::BoolAnd
                | PbAggKind::BoolOr
                | PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::Sum0
                | PbAggKind::Count
                | PbAggKind::Avg
                | PbAggKind::ApproxCountDistinct
                | PbAggKind::VarPop
                | PbAggKind::VarSamp
                | PbAggKind::StddevPop
                | PbAggKind::StddevSamp)
        };
    }
    pub use result_unaffected_by_order_by;

    /// Matches the builtin kinds `FirstValue`, `LastValue`, `PercentileCont`,
    /// `PercentileDisc` and `Mode`.
    ///
    /// NOTE(review): the name indicates these aggregates require an ordering
    /// to be specified; the enforcement itself is not in this file.
    #[macro_export]
    macro_rules! must_have_order_by {
        () => {
            AggType::Builtin(
                PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            )
        };
    }
    pub use must_have_order_by;

    /// Matches the builtin kinds `BitAnd`, `BitOr`, `BoolAnd`, `BoolOr`,
    /// `Min`, `Max` and `ApproxCountDistinct`.
    ///
    /// NOTE(review): the name indicates that de-duplicating the input does not
    /// change the result of these aggregates — confirm with the callers.
    #[macro_export]
    macro_rules! result_unaffected_by_distinct {
        () => {
            AggType::Builtin(
                PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::ApproxCountDistinct,
            )
        };
    }
    pub use result_unaffected_by_distinct;

    /// Matches the builtin kinds listed below plus every `UserDefined` and
    /// `WrapScalar` aggregate.
    ///
    /// `AggType::partial_to_total` returns `None` for everything matched here.
    #[macro_export]
    macro_rules! simply_cannot_two_phase {
        () => {
            AggType::Builtin(
                PbAggKind::StringAgg
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
            )
            | AggType::UserDefined(_)
            | AggType::WrapScalar(_)
        };
    }
    pub use simply_cannot_two_phase;

    /// Matches the builtin kinds listed below plus every `UserDefined`
    /// aggregate.
    ///
    /// NOTE(review): the name indicates these aggregates keep their state as a
    /// single value; the state implementation is outside this file.
    #[macro_export]
    macro_rules! single_value_state {
        () => {
            AggType::Builtin(
                PbAggKind::Sum
                    | PbAggKind::Sum0
                    | PbAggKind::Count
                    | PbAggKind::BitAnd
                    | PbAggKind::BitOr
                    | PbAggKind::BitXor
                    | PbAggKind::BoolAnd
                    | PbAggKind::BoolOr
                    | PbAggKind::ApproxCountDistinct
                    | PbAggKind::InternalLastSeenValue
                    | PbAggKind::ApproxPercentile,
            ) | AggType::UserDefined(_)
        };
    }
    pub use single_value_state;

    /// Matches the builtin kinds `Max` and `Min`.
    ///
    /// NOTE(review): the name indicates these use a single-value state only
    /// when the input is append-only — confirm against the state selection
    /// code. Note that `Min`/`Max` are also matched by
    /// `materialized_input_state!`.
    #[macro_export]
    macro_rules! single_value_state_iff_in_append_only {
        () => {
            AggType::Builtin(PbAggKind::Max | PbAggKind::Min)
        };
    }
    pub use single_value_state_iff_in_append_only;

    /// Matches the builtin kinds listed below plus every `WrapScalar`
    /// aggregate.
    ///
    /// NOTE(review): the name indicates these aggregates keep their input
    /// materialized as state; the state implementation is outside this file.
    #[macro_export]
    macro_rules! materialized_input_state {
        () => {
            AggType::Builtin(
                PbAggKind::Min
                    | PbAggKind::Max
                    | PbAggKind::FirstValue
                    | PbAggKind::LastValue
                    | PbAggKind::StringAgg
                    | PbAggKind::ArrayAgg
                    | PbAggKind::JsonbAgg
                    | PbAggKind::JsonbObjectAgg
                    | PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode,
            ) | AggType::WrapScalar(_)
        };
    }
    pub use materialized_input_state;

    /// Matches the builtin kinds `PercentileCont`, `PercentileDisc`, `Mode`
    /// and `ApproxPercentile`.
    ///
    /// NOTE(review): presumably the ordered-set aggregates
    /// (`... WITHIN GROUP (ORDER BY ...)`) — confirm against the binder.
    #[macro_export]
    macro_rules! ordered_set {
        () => {
            AggType::Builtin(
                PbAggKind::PercentileCont
                    | PbAggKind::PercentileDisc
                    | PbAggKind::Mode
                    | PbAggKind::ApproxPercentile,
            )
        };
    }
    pub use ordered_set;
}
503
impl AggType {
    /// Maps this aggregate, viewed as the partial stage of a two-stage
    /// aggregation, to the aggregate used for the total stage.
    ///
    /// * `BitXor`, `Min`, `Max`, `Sum` and `InternalLastSeenValue` map to
    ///   themselves.
    /// * `Sum0` and `Count` both map to `Sum0`.
    /// * Everything else yields `None`.
    pub fn partial_to_total(&self) -> Option<Self> {
        match self {
            AggType::Builtin(
                PbAggKind::BitXor
                | PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::InternalLastSeenValue,
            ) => Some(self.clone()),
            AggType::Builtin(PbAggKind::Sum0 | PbAggKind::Count) => {
                Some(Self::Builtin(PbAggKind::Sum0))
            }
            agg_types::simply_cannot_two_phase!() => None,
            agg_types::rewritten!() => None,
            // Kinds that the constructors in this file never place inside
            // `Builtin`; listed so the match stays exhaustive without a
            // catch-all arm.
            AggType::Builtin(
                PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
            ) => None,
        }
    }
}
527
/// Arguments of an aggregate call, stored as two parallel slices: the i-th
/// argument has type `data_types[i]` and is read from input column
/// `val_indices[i]`.
#[derive(Clone, Debug, Default)]
pub struct AggArgs {
    /// Data type of each argument.
    data_types: Box<[DataType]>,
    /// Input column index of each argument.
    val_indices: Box<[usize]>,
}
534
535impl AggArgs {
536 pub fn from_protobuf(args: &[PbInputRef]) -> Result<Self> {
537 Ok(AggArgs {
538 data_types: args
539 .iter()
540 .map(|arg| DataType::from(arg.get_type().unwrap()))
541 .collect(),
542 val_indices: args.iter().map(|arg| arg.get_index() as usize).collect(),
543 })
544 }
545
546 pub fn arg_types(&self) -> &[DataType] {
548 &self.data_types
549 }
550
551 pub fn val_indices(&self) -> &[usize] {
553 &self.val_indices
554 }
555}
556
557impl FromIterator<(DataType, usize)> for AggArgs {
558 fn from_iter<T: IntoIterator<Item = (DataType, usize)>>(iter: T) -> Self {
559 let (data_types, val_indices): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
560 AggArgs {
561 data_types: data_types.into(),
562 val_indices: val_indices.into(),
563 }
564 }
565}