// risingwave_expr/aggregate/def.rs
use std::fmt::Display;
use std::iter::Peekable;
use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context;
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
use risingwave_pb::expr::{
PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
};
use crate::expr::{
build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token,
};
use crate::Result;
/// A single aggregate call: which aggregate to compute, over which input columns, and with
/// which modifiers (ordering, filter, `DISTINCT`, direct arguments).
///
/// Built from protobuf by `AggCall::from_protobuf`, or from a compact text form by
/// `AggCall::from_pretty`.
#[derive(Debug, Clone)]
pub struct AggCall {
/// The aggregate to evaluate: a builtin kind, a user-defined aggregate, or a wrapped scalar.
pub agg_type: AggType,
/// Input column indices of the arguments, together with their data types.
pub args: AggArgs,
/// Data type of the aggregate's output.
pub return_type: DataType,
/// Ordering keys of the call (`order_by` in protobuf, `orderby` in the pretty form);
/// empty when no ordering was given.
pub column_orders: Vec<ColumnOrder>,
/// Optional filter expression.
/// NOTE(review): presumably the SQL `FILTER (WHERE ...)` clause — confirm at the executor.
pub filter: Option<Arc<dyn Expression>>,
/// Whether the call was marked `distinct`.
pub distinct: bool,
/// Constant (literal) arguments, decoded from `PbAggCall::direct_args`.
pub direct_args: Vec<LiteralExpression>,
}
impl AggCall {
pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
let agg_type = AggType::from_protobuf_flatten(
agg_call.get_kind()?,
agg_call.udf.as_ref(),
agg_call.scalar.as_ref(),
)?;
let args = AggArgs::from_protobuf(agg_call.get_args())?;
let column_orders = agg_call
.get_order_by()
.iter()
.map(|col_order| {
let col_idx = col_order.get_column_index() as usize;
let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap());
ColumnOrder::new(col_idx, order_type)
})
.collect();
let filter = match agg_call.filter {
Some(ref pb_filter) => Some(build_from_prost(pb_filter)?.into()), None => None,
};
let direct_args = agg_call
.direct_args
.iter()
.map(|arg| {
let data_type = DataType::from(arg.get_type().unwrap());
LiteralExpression::new(
data_type.clone(),
Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
)
})
.collect_vec();
Ok(AggCall {
agg_type,
args,
return_type: DataType::from(agg_call.get_return_type()?),
column_orders,
filter,
distinct: agg_call.distinct,
direct_args,
})
}
pub fn from_pretty(s: impl AsRef<str>) -> Self {
let tokens = crate::expr::lexer(s.as_ref());
Parser::new(tokens.into_iter()).parse_aggregation()
}
pub fn with_filter(mut self, filter: BoxedExpression) -> Self {
self.filter = Some(filter.into());
self
}
}
/// Token-level parser behind `AggCall::from_pretty`.
struct Parser<Iter>
where
    Iter: Iterator,
{
    /// Remaining tokens. `peek` gives the one-token lookahead the parsing methods rely on.
    tokens: Peekable<Iter>,
}
impl<Iter: Iterator<Item = Token>> Parser<Iter> {
    /// Wraps a token stream for parsing.
    fn new(tokens: Iter) -> Self {
        Self {
            tokens: tokens.peekable(),
        }
    }

    /// Parses one aggregate call of the form
    /// `( <kind> : <type> [<index> : <type>]* [distinct] [orderby [<index> : asc|desc]*] )`.
    ///
    /// # Panics
    ///
    /// Panics if the tokens do not follow this grammar, including when anything other than
    /// the closing `)` follows the optional clauses.
    fn parse_aggregation(&mut self) -> AggCall {
        assert_eq!(self.tokens.next(), Some(Token::LParen), "Expected a (");
        let func = self.parse_function();
        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
        let ty = self.parse_type();

        let mut children = Vec::new();
        while self.next_is_index() {
            children.push(self.parse_arg());
        }

        let distinct = self.eat_keyword("distinct");

        let mut column_orders = Vec::new();
        if self.eat_keyword("orderby") {
            while self.next_is_index() {
                column_orders.push(self.parse_orderkey());
            }
        }

        // Check the closing token instead of skipping it blindly. Otherwise an unknown or
        // misspelled keyword (e.g. `distinc`) would be silently consumed in place of `)`.
        assert_eq!(self.tokens.next(), Some(Token::RParen), "Expected a )");

        AggCall {
            agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
            args: children
                .into_iter()
                .map(|(idx, arg_ty)| (arg_ty, idx))
                .collect(),
            return_type: ty,
            column_orders,
            filter: None,
            distinct,
            direct_args: Vec::new(),
        }
    }

    /// Returns whether the next token is an `Index` token, without consuming it.
    fn next_is_index(&mut self) -> bool {
        matches!(self.tokens.peek(), Some(Token::Index(_)))
    }

    /// Consumes the next token if it is the literal `keyword`; returns whether it did.
    fn eat_keyword(&mut self, keyword: &str) -> bool {
        let found = matches!(self.tokens.peek(), Some(Token::Literal(s)) if s == keyword);
        if found {
            self.tokens.next();
        }
        found
    }

    /// Consumes a literal token and parses it as a [`DataType`].
    fn parse_type(&mut self) -> DataType {
        match self.tokens.next().expect("Unexpected end of input") {
            Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
            t => panic!("Expected a Literal, got {t:?}"),
        }
    }

    /// Consumes an `Index` token and returns the index it carries.
    fn parse_index(&mut self) -> usize {
        match self.tokens.next().expect("Unexpected end of input") {
            Token::Index(idx) => idx,
            t => panic!("Expected an Index, got {t:?}"),
        }
    }

    /// Parses an argument `<index> : <type>` into `(index, type)`.
    fn parse_arg(&mut self) -> (usize, DataType) {
        let idx = self.parse_index();
        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
        let ty = self.parse_type();
        (idx, ty)
    }

    /// Parses the aggregate kind; the name is matched case-insensitively against the
    /// protobuf enum names.
    fn parse_function(&mut self) -> PbAggKind {
        match self.tokens.next().expect("Unexpected end of input") {
            Token::Literal(name) => {
                PbAggKind::from_str_name(&name.to_uppercase()).expect_str("function", &name)
            }
            t => panic!("Expected a Literal, got {t:?}"),
        }
    }

    /// Parses an ordering key `<index> : asc|desc`.
    fn parse_orderkey(&mut self) -> ColumnOrder {
        let idx = self.parse_index();
        assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
        let order = match self.tokens.next().expect("Unexpected end of input") {
            Token::Literal(s) if s == "asc" => OrderType::ascending(),
            Token::Literal(s) if s == "desc" => OrderType::descending(),
            t => panic!("Expected asc or desc, got {t:?}"),
        };
        ColumnOrder::new(idx, order)
    }
}
/// The aggregate function an [`AggCall`] evaluates.
///
/// `AggType::from_protobuf` and `AggType::from_protobuf_flatten` never produce
/// `Builtin(Unspecified | UserDefined | WrapScalar)`, and `From<PbAggKind>` asserts against
/// those three kinds. `UserDefined`/`WrapScalar` are represented by the dedicated variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumAsInner)]
pub enum AggType {
/// A builtin aggregate identified by its protobuf kind.
Builtin(PbAggKind),
/// A user-defined aggregate, described by its function metadata.
UserDefined(PbUserDefinedFunctionMetadata),
/// An aggregate backed by a scalar expression node (displayed as `wrap_scalar`).
WrapScalar(PbExprNode),
}
impl Display for AggType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Builtin(kind) => write!(f, "{}", kind.as_str_name().to_lowercase()),
Self::UserDefined(_) => write!(f, "udaf"),
Self::WrapScalar(_) => write!(f, "wrap_scalar"),
}
}
}
impl FromStr for AggType {
    type Err = ();

    /// Parses `s` with `PbAggKind::from_str` and wraps the result as a builtin aggregate.
    ///
    /// Kinds that cannot be represented as [`AggType::Builtin`] (`Unspecified`,
    /// `UserDefined`, `WrapScalar`) are rejected with `Err(())`. This keeps the same invariant
    /// that `From<PbAggKind>` asserts and that the protobuf decoders enforce.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match PbAggKind::from_str(s)? {
            PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar => Err(()),
            kind => Ok(AggType::Builtin(kind)),
        }
    }
}
impl From<PbAggKind> for AggType {
    /// Wraps a builtin aggregate kind as [`AggType::Builtin`].
    ///
    /// # Panics
    ///
    /// Panics if `pb` is `Unspecified` (never a valid aggregate), or `UserDefined` /
    /// `WrapScalar` (these need payloads carried by the other variants; use
    /// `AggType::from_protobuf_flatten` for them).
    fn from(pb: PbAggKind) -> Self {
        assert!(
            !matches!(
                pb,
                PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar
            ),
            "{pb:?} is not a builtin aggregate kind"
        );
        AggType::Builtin(pb)
    }
}
impl AggType {
    /// Builds an [`AggType`] from the flattened form used by [`PbAggCall`], where the kind and
    /// its optional payloads (UDF metadata, scalar expression) are separate fields.
    ///
    /// # Errors
    ///
    /// Fails if `pb_kind` is `Unspecified`, or if it is `UserDefined` / `WrapScalar` and the
    /// matching payload is `None`.
    pub fn from_protobuf_flatten(
        pb_kind: PbAggKind,
        user_defined: Option<&PbUserDefinedFunctionMetadata>,
        scalar: Option<&PbExprNode>,
    ) -> Result<Self> {
        let agg_type = match pb_kind {
            PbAggKind::Unspecified => bail!("Unrecognized agg."),
            PbAggKind::UserDefined => {
                let meta = user_defined.context("expect user defined")?;
                AggType::UserDefined(meta.clone())
            }
            PbAggKind::WrapScalar => {
                let expr = scalar.context("expect scalar")?;
                AggType::WrapScalar(expr.clone())
            }
            builtin => AggType::Builtin(builtin),
        };
        Ok(agg_type)
    }

    /// Returns only the protobuf kind tag, dropping any payload.
    pub fn to_protobuf_simple(&self) -> PbAggKind {
        match *self {
            Self::Builtin(kind) => kind,
            Self::UserDefined(..) => PbAggKind::UserDefined,
            Self::WrapScalar(..) => PbAggKind::WrapScalar,
        }
    }

    /// Decodes an [`AggType`] from a [`PbAggType`] message.
    ///
    /// # Errors
    ///
    /// Fails if `kind` is not a known enum value or is `Unspecified`, or if a
    /// `UserDefined` / `WrapScalar` kind lacks its payload field.
    pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
        let kind = PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")?;
        Ok(match kind {
            PbAggKind::Unspecified => bail!("Unrecognized agg."),
            PbAggKind::UserDefined => AggType::UserDefined(pb_type.get_udf_meta()?.clone()),
            PbAggKind::WrapScalar => AggType::WrapScalar(pb_type.get_scalar_expr()?.clone()),
            builtin => AggType::Builtin(builtin),
        })
    }

    /// Encodes this aggregate type as a [`PbAggType`]. The payload field that matches the
    /// variant is populated and the other is left `None`.
    pub fn to_protobuf(&self) -> PbAggType {
        let (udf_meta, scalar_expr) = match self {
            Self::Builtin(_) => (None, None),
            Self::UserDefined(meta) => (Some(meta.clone()), None),
            Self::WrapScalar(expr) => (None, Some(expr.clone())),
        };
        PbAggType {
            kind: self.to_protobuf_simple() as _,
            udf_meta,
            scalar_expr,
        }
    }
}
/// Pattern macros that classify [`AggType`] values.
///
/// Each macro expands to a pattern, so it can be used directly as a `match` arm or inside
/// `matches!`. The expansions name `AggType` and `PbAggKind` unqualified, so both must be in
/// scope wherever a macro is used. `#[macro_export]` places every macro at the crate root; the
/// `pub use` lines re-export them so they can also be written as `agg_types::name!()`.
pub mod agg_types {
/// Matches `PercentileCont`, `PercentileDisc` and `Mode`.
/// NOTE(review): per the name, these kinds lack a streaming implementation — confirm at use sites.
#[macro_export]
macro_rules! unimplemented_in_stream {
() => {
AggType::Builtin(
PbAggKind::PercentileCont | PbAggKind::PercentileDisc | PbAggKind::Mode,
)
};
}
pub use unimplemented_in_stream;
/// Matches `Avg`, `StddevPop`, `StddevSamp`, `VarPop`, `VarSamp`, `Grouping` and
/// `ApproxPercentile`. `AggType::partial_to_total` returns `None` for all of them.
/// NOTE(review): per the name, these are rewritten into other aggregates/expressions before
/// execution — confirm in the planner.
#[macro_export]
macro_rules! rewritten {
() => {
AggType::Builtin(
PbAggKind::Avg
| PbAggKind::StddevPop
| PbAggKind::StddevSamp
| PbAggKind::VarPop
| PbAggKind::VarSamp
| PbAggKind::Grouping
| PbAggKind::ApproxPercentile
)
};
}
pub use rewritten;
/// Matches kinds whose result is listed as independent of input order: the bitwise and
/// boolean reductions, `Min`/`Max`, `Sum`/`Sum0`, `Count`, `Avg`, `ApproxCountDistinct`, and
/// the variance / standard-deviation kinds.
#[macro_export]
macro_rules! result_unaffected_by_order_by {
() => {
AggType::Builtin(PbAggKind::BitAnd
| PbAggKind::BitOr
| PbAggKind::BitXor | PbAggKind::BoolAnd
| PbAggKind::BoolOr
| PbAggKind::Min
| PbAggKind::Max
| PbAggKind::Sum
| PbAggKind::Sum0
| PbAggKind::Count
| PbAggKind::Avg
| PbAggKind::ApproxCountDistinct
| PbAggKind::VarPop
| PbAggKind::VarSamp
| PbAggKind::StddevPop
| PbAggKind::StddevSamp)
};
}
pub use result_unaffected_by_order_by;
/// Matches `FirstValue`, `LastValue`, `PercentileCont`, `PercentileDisc` and `Mode`, whose
/// results are defined relative to an ordering of the input.
#[macro_export]
macro_rules! must_have_order_by {
() => {
AggType::Builtin(
PbAggKind::FirstValue
| PbAggKind::LastValue
| PbAggKind::PercentileCont
| PbAggKind::PercentileDisc
| PbAggKind::Mode,
)
};
}
pub use must_have_order_by;
/// Matches kinds whose result is listed as unchanged by `DISTINCT`: `BitAnd`, `BitOr`,
/// `BoolAnd`, `BoolOr`, `Min`, `Max` and `ApproxCountDistinct`. `BitXor` is not included,
/// because a repeated value changes an XOR.
#[macro_export]
macro_rules! result_unaffected_by_distinct {
() => {
AggType::Builtin(
PbAggKind::BitAnd
| PbAggKind::BitOr
| PbAggKind::BoolAnd
| PbAggKind::BoolOr
| PbAggKind::Min
| PbAggKind::Max
| PbAggKind::ApproxCountDistinct,
)
};
}
pub use result_unaffected_by_distinct;
/// Matches the listed builtin kinds plus every `UserDefined` and `WrapScalar` aggregate.
/// `AggType::partial_to_total` returns `None` for all of them, i.e. they are not split into
/// a partial and a total phase.
#[macro_export]
macro_rules! simply_cannot_two_phase {
() => {
AggType::Builtin(
PbAggKind::StringAgg
| PbAggKind::ApproxCountDistinct
| PbAggKind::ArrayAgg
| PbAggKind::JsonbAgg
| PbAggKind::JsonbObjectAgg
| PbAggKind::FirstValue
| PbAggKind::LastValue
| PbAggKind::PercentileCont
| PbAggKind::PercentileDisc
| PbAggKind::Mode
| PbAggKind::BoolAnd
| PbAggKind::BoolOr
| PbAggKind::BitAnd
| PbAggKind::BitOr
)
| AggType::UserDefined(_)
| AggType::WrapScalar(_)
};
}
pub use simply_cannot_two_phase;
/// Matches the builtin kinds listed as keeping a single-value state, plus every
/// `UserDefined` aggregate.
#[macro_export]
macro_rules! single_value_state {
() => {
AggType::Builtin(
PbAggKind::Sum
| PbAggKind::Sum0
| PbAggKind::Count
| PbAggKind::BitAnd
| PbAggKind::BitOr
| PbAggKind::BitXor
| PbAggKind::BoolAnd
| PbAggKind::BoolOr
| PbAggKind::ApproxCountDistinct
| PbAggKind::InternalLastSeenValue
| PbAggKind::ApproxPercentile,
) | AggType::UserDefined(_)
};
}
pub use single_value_state;
/// Matches `Max` and `Min`. NOTE(review): per the name, these keep a single-value state only
/// when the input is append-only — confirm at use sites.
#[macro_export]
macro_rules! single_value_state_iff_in_append_only {
() => {
AggType::Builtin(PbAggKind::Max | PbAggKind::Min)
};
}
pub use single_value_state_iff_in_append_only;
/// Matches the builtin kinds listed as keeping materialized input state, plus every
/// `WrapScalar` aggregate.
#[macro_export]
macro_rules! materialized_input_state {
() => {
AggType::Builtin(
PbAggKind::Min
| PbAggKind::Max
| PbAggKind::FirstValue
| PbAggKind::LastValue
| PbAggKind::StringAgg
| PbAggKind::ArrayAgg
| PbAggKind::JsonbAgg
| PbAggKind::JsonbObjectAgg,
) | AggType::WrapScalar(_)
};
}
pub use materialized_input_state;
/// Matches the ordered-set aggregate kinds: `PercentileCont`, `PercentileDisc`, `Mode` and
/// `ApproxPercentile`.
#[macro_export]
macro_rules! ordered_set {
() => {
AggType::Builtin(
PbAggKind::PercentileCont
| PbAggKind::PercentileDisc
| PbAggKind::Mode
| PbAggKind::ApproxPercentile,
)
};
}
pub use ordered_set;
}
impl AggType {
    /// Returns the aggregate that merges partial results of `self` in a two-phase
    /// aggregation, or `None` when `self` is not split into phases.
    ///
    /// Partial `count`/`sum0` results are merged with `sum0`. `min`, `max`, `sum`, `bit_xor`
    /// and `internal_last_seen_value` merge with themselves.
    pub fn partial_to_total(&self) -> Option<Self> {
        match self {
            Self::Builtin(PbAggKind::Count | PbAggKind::Sum0) => {
                Some(Self::Builtin(PbAggKind::Sum0))
            }
            Self::Builtin(
                PbAggKind::Min
                | PbAggKind::Max
                | PbAggKind::Sum
                | PbAggKind::BitXor
                | PbAggKind::InternalLastSeenValue,
            ) => Some(self.clone()),
            // Kinds classified by `simply_cannot_two_phase!` / `rewritten!`, and the
            // pseudo-kinds that are never valid builtins, have no total form.
            agg_types::simply_cannot_two_phase!()
            | agg_types::rewritten!()
            | Self::Builtin(
                PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
            ) => None,
        }
    }
}
/// Arguments of an aggregate call: for each argument, the input column it reads and that
/// column's data type. `data_types[i]` and `val_indices[i]` describe the same argument (both
/// `from_protobuf` and `FromIterator` fill them in parallel).
#[derive(Clone, Debug, Default)]
pub struct AggArgs {
/// Data type of each argument, in argument order.
data_types: Box<[DataType]>,
/// Input column index of each argument, in argument order.
val_indices: Box<[usize]>,
}
impl AggArgs {
    /// Decodes aggregate arguments from protobuf input references, preserving their order.
    ///
    /// # Errors
    ///
    /// Returns an error if any input reference is missing its data type.
    pub fn from_protobuf(args: &[PbInputRef]) -> Result<Self> {
        let mut data_types = Vec::with_capacity(args.len());
        let mut val_indices = Vec::with_capacity(args.len());
        for arg in args {
            // Propagate a missing type as an error instead of panicking on external input.
            data_types.push(DataType::from(arg.get_type()?));
            val_indices.push(arg.get_index() as usize);
        }
        Ok(AggArgs {
            data_types: data_types.into(),
            val_indices: val_indices.into(),
        })
    }

    /// Data types of the arguments, in argument order.
    pub fn arg_types(&self) -> &[DataType] {
        &self.data_types
    }

    /// Input column indices of the arguments, in argument order.
    pub fn val_indices(&self) -> &[usize] {
        &self.val_indices
    }
}
impl FromIterator<(DataType, usize)> for AggArgs {
fn from_iter<T: IntoIterator<Item = (DataType, usize)>>(iter: T) -> Self {
let (data_types, val_indices): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
AggArgs {
data_types: data_types.into(),
val_indices: val_indices.into(),
}
}
}