use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::{bail, bail_not_implemented, not_implemented};
use risingwave_expr::aggregate::{agg_types, AggType, PbAggKind};
use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
use super::utils::impl_distill_by_unit;
use super::{
BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef,
PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamProject, StreamShare,
StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
};
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{
AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
OrderBy, WindowFunction,
};
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::plan_node::generic::GenericPlanNode;
use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
use crate::optimizer::plan_node::{
gen_filter_and_pushdown, BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject,
PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
};
use crate::optimizer::property::{Distribution, Order, RequiredDist};
use crate::utils::{
ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
};
pub struct AggInfo {
pub calls: Vec<PlanAggCall>,
pub col_mapping: ColIndexMapping,
}
pub struct SeparatedAggInfo {
normal: AggInfo,
approx: AggInfo,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAgg {
pub base: PlanBase<Logical>,
core: Agg<PlanRef>,
}
impl LogicalAgg {
fn gen_stateless_two_phase_streaming_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
debug_assert!(self.group_key().is_empty());
let mut core = self.core.clone();
let (non_approx_percentile_col_mapping, approx_percentile_col_mapping, approx_percentile) =
self.prepare_approx_percentile(&mut core, stream_input.clone())?;
if core.agg_calls.is_empty() {
if let Some(approx_percentile) = approx_percentile {
return Ok(approx_percentile);
};
bail!("expected at least one agg call");
}
let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
let total_agg_calls = core
.agg_calls
.iter()
.enumerate()
.map(|(partial_output_idx, agg_call)| {
agg_call.partial_to_total_agg_call(partial_output_idx)
})
.collect_vec();
let local_agg = StreamStatelessSimpleAgg::new(core);
let exchange =
RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let must_output_per_barrier = need_row_merge;
let global_agg = new_stream_simple_agg(
Agg::new(total_agg_calls, IndexSet::empty(), exchange),
must_output_per_barrier,
);
Self::add_row_merge_if_needed(
approx_percentile,
global_agg.into(),
approx_percentile_col_mapping,
non_approx_percentile_col_mapping,
)
}
fn gen_vnode_two_phase_streaming_agg_plan(
&self,
stream_input: PlanRef,
dist_key: &[usize],
) -> Result<PlanRef> {
let mut core = self.core.clone();
let (non_approx_percentile_col_mapping, approx_percentile_col_mapping, approx_percentile) =
self.prepare_approx_percentile(&mut core, stream_input.clone())?;
if core.agg_calls.is_empty() {
if let Some(approx_percentile) = approx_percentile {
return Ok(approx_percentile);
};
bail!("expected at least one agg call");
}
let need_row_merge = Self::need_row_merge(&approx_percentile);
let input_col_num = stream_input.schema().len(); let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
let vnode_col_idx = project.base.schema().len() - 1;
let mut local_group_key = self.group_key().clone();
local_group_key.insert(vnode_col_idx);
let n_local_group_key = local_group_key.len();
let local_agg = new_stream_hash_agg(
Agg::new(core.agg_calls.to_vec(), local_group_key, project.into()),
Some(vnode_col_idx),
);
let local_agg_group_key_cardinality = local_agg.group_key().len();
let local_group_key_without_vnode =
&local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
let global_group_key = local_agg
.i2o_col_mapping()
.rewrite_dist_key(local_group_key_without_vnode)
.expect("some input group key could not be mapped");
let global_agg = if self.group_key().is_empty() {
let exchange =
RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let must_output_per_barrier = need_row_merge;
let global_agg = new_stream_simple_agg(
Agg::new(
core.agg_calls
.iter()
.enumerate()
.map(|(partial_output_idx, agg_call)| {
agg_call
.partial_to_total_agg_call(n_local_group_key + partial_output_idx)
})
.collect(),
global_group_key.into_iter().collect(),
exchange,
),
must_output_per_barrier,
);
global_agg.into()
} else {
assert!(!need_row_merge);
let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
.enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let global_agg = new_stream_hash_agg(
Agg::new(
core.agg_calls
.iter()
.enumerate()
.map(|(partial_output_idx, agg_call)| {
agg_call
.partial_to_total_agg_call(n_local_group_key + partial_output_idx)
})
.collect(),
global_group_key.into_iter().collect(),
exchange,
),
None,
);
global_agg.into()
};
Self::add_row_merge_if_needed(
approx_percentile,
global_agg,
approx_percentile_col_mapping,
non_approx_percentile_col_mapping,
)
}
fn gen_single_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
let mut core = self.core.clone();
let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
core.input = input;
Ok(new_stream_simple_agg(core, false).into())
}
fn gen_shuffle_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
let input =
RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
.enforce_if_not_satisfies(stream_input, &Order::any())?;
let mut core = self.core.clone();
core.input = input;
Ok(new_stream_hash_agg(core, None).into())
}
fn gen_dist_stream_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
use super::stream::prelude::*;
let input_dist = stream_input.distribution();
debug_assert!(*input_dist != Distribution::Broadcast);
if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
return self.gen_shuffle_plan(stream_input);
}
if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
return self.gen_single_plan(stream_input);
}
debug_assert!(if !self.group_key().is_empty() {
self.core.must_try_two_phase_agg()
} else {
self.core.can_two_phase_agg()
});
if self.group_key().is_empty()
&& self
.core
.all_local_aggs_are_stateless(stream_input.append_only())
&& input_dist.satisfies(&RequiredDist::AnyShard)
{
return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
}
let stream_input =
if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
RequiredDist::shard_by_key(
stream_input.schema().len(),
stream_input.expect_stream_key(),
)
.enforce_if_not_satisfies(stream_input, &Order::any())?
} else {
stream_input
};
let input_dist = stream_input.distribution();
match input_dist {
Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
if (self.group_key().is_empty()
|| !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
{
let dist_key = dist_key.clone();
return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
}
_ => {}
}
if !self.group_key().is_empty() {
self.gen_shuffle_plan(stream_input)
} else {
self.gen_single_plan(stream_input)
}
}
fn prepare_approx_percentile(
&self,
core: &mut Agg<PlanRef>,
stream_input: PlanRef,
) -> Result<(ColIndexMapping, ColIndexMapping, Option<PlanRef>)> {
let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
let AggInfo {
calls: non_approx_percentile_agg_calls,
col_mapping: non_approx_percentile_col_mapping,
} = normal;
let AggInfo {
calls: approx_percentile_agg_calls,
col_mapping: approx_percentile_col_mapping,
} = approx;
if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
bail_not_implemented!(
"two-phase streaming approx percentile aggregation with group key, \
please use single phase aggregation instead"
);
}
let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
&& !approx_percentile_agg_calls.is_empty())
|| approx_percentile_agg_calls.len() >= 2;
core.input = if needs_row_merge {
StreamShare::new_from_input(stream_input.clone()).into()
} else {
stream_input
};
core.agg_calls = non_approx_percentile_agg_calls;
let approx_percentile =
self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
Ok((
non_approx_percentile_col_mapping,
approx_percentile_col_mapping,
approx_percentile,
))
}
fn need_row_merge(approx_percentile: &Option<PlanRef>) -> bool {
approx_percentile.is_some()
}
fn add_row_merge_if_needed(
approx_percentile: Option<PlanRef>,
global_agg: PlanRef,
approx_percentile_col_mapping: ColIndexMapping,
non_approx_percentile_col_mapping: ColIndexMapping,
) -> Result<PlanRef> {
let need_row_merge = Self::need_row_merge(&approx_percentile);
if let Some(approx_percentile) = approx_percentile {
assert!(need_row_merge);
let row_merge = StreamRowMerge::new(
approx_percentile,
global_agg,
approx_percentile_col_mapping,
non_approx_percentile_col_mapping,
)?;
Ok(row_merge.into())
} else {
assert!(!need_row_merge);
Ok(global_agg)
}
}
fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
let estimated_len = self.agg_calls().len() - 1;
let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
approx_percentile_agg_calls.push(agg_call.clone());
approx_percentile_col_mapping.push(Some(output_idx));
} else {
non_approx_percentile_agg_calls.push(agg_call.clone());
non_approx_percentile_col_mapping.push(Some(output_idx));
}
}
let normal = AggInfo {
calls: non_approx_percentile_agg_calls,
col_mapping: ColIndexMapping::new(
non_approx_percentile_col_mapping,
self.agg_calls().len(),
),
};
let approx = AggInfo {
calls: approx_percentile_agg_calls,
col_mapping: ColIndexMapping::new(
approx_percentile_col_mapping,
self.agg_calls().len(),
),
};
SeparatedAggInfo { normal, approx }
}
fn build_approx_percentile_agg(
&self,
input: PlanRef,
approx_percentile_agg_call: &PlanAggCall,
) -> Result<PlanRef> {
let local_approx_percentile =
StreamLocalApproxPercentile::new(input, approx_percentile_agg_call);
let exchange = RequiredDist::single()
.enforce_if_not_satisfies(local_approx_percentile.into(), &Order::any())?;
let global_approx_percentile =
StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
Ok(global_approx_percentile.into())
}
fn build_approx_percentile_aggs(
&self,
input: PlanRef,
approx_percentile_agg_call: &[PlanAggCall],
) -> Result<Option<PlanRef>> {
if approx_percentile_agg_call.is_empty() {
return Ok(None);
}
let approx_percentile_plans: Vec<PlanRef> = approx_percentile_agg_call
.iter()
.map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
.try_collect()?;
assert!(!approx_percentile_plans.is_empty());
let mut iter = approx_percentile_plans.into_iter();
let mut acc = iter.next().unwrap();
for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
let new_size = current_size + 1;
let row_merge = StreamRowMerge::new(
acc,
plan,
ColIndexMapping::identity_or_none(current_size, new_size),
ColIndexMapping::new(vec![Some(current_size)], new_size),
)
.expect("failed to build row merge");
acc = row_merge.into();
}
Ok(Some(acc))
}
pub fn core(&self) -> &Agg<PlanRef> {
&self.core
}
}
pub struct LogicalAggBuilder {
input_proj_builder: ProjectBuilder,
group_key: IndexSet,
grouping_sets: Vec<IndexSet>,
agg_calls: Vec<PlanAggCall>,
error: Option<RwError>,
is_in_filter_clause: bool,
}
impl LogicalAggBuilder {
fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
let mut input_proj_builder = ProjectBuilder::default();
let mut gen_group_key_and_grouping_sets =
|grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
let grouping_sets: Vec<IndexSet> = grouping_sets
.into_iter()
.map(|set| {
set.into_iter()
.map(|expr| input_proj_builder.add_expr(&expr))
.try_collect()
.map_err(|err| not_implemented!("{err} inside GROUP BY"))
})
.try_collect()?;
let group_key = grouping_sets
.iter()
.fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
acc.union(&x.to_bitset()).collect()
});
Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
};
let (group_key, grouping_sets) = match group_by {
GroupBy::GroupKey(group_key) => {
let group_key = group_key
.into_iter()
.map(|expr| input_proj_builder.add_expr(&expr))
.try_collect()
.map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
(group_key, vec![])
}
GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
GroupBy::Rollup(rollup) => {
let grouping_sets = (0..=rollup.len())
.map(|n| {
rollup
.iter()
.take(n)
.flat_map(|x| x.iter().cloned())
.collect_vec()
})
.collect_vec();
gen_group_key_and_grouping_sets(grouping_sets)?
}
GroupBy::Cube(cube) => {
let grouping_sets = cube
.into_iter()
.powerset()
.map(|x| x.into_iter().flatten().collect_vec())
.collect_vec();
gen_group_key_and_grouping_sets(grouping_sets)?
}
};
Ok(LogicalAggBuilder {
group_key,
grouping_sets,
agg_calls: vec![],
error: None,
input_proj_builder,
is_in_filter_clause: false,
})
}
pub fn build(self, input: PlanRef) -> LogicalAgg {
let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
Agg::new(self.agg_calls, self.group_key, logical_project.into())
.with_grouping_sets(self.grouping_sets)
.into()
}
fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
let rewritten_expr = self.rewrite_expr(expr);
if let Some(error) = self.error.take() {
return Err(error);
}
Ok(rewritten_expr)
}
pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
if let Some(input_index) = self.input_proj_builder.expr_index(expr) {
if let Some(index) = self
.group_key
.indices()
.position(|group_key| group_key == input_index)
{
return Some(index);
}
}
None
}
fn schema_agg_start_offset(&self) -> usize {
self.group_key.len()
}
pub(crate) fn general_rewrite_agg_call(
agg_call: AggCall,
mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
) -> Result<ExprImpl> {
match agg_call.agg_type {
AggType::Builtin(PbAggKind::Avg) => {
assert_eq!(agg_call.args.len(), 1);
let sum = ExprImpl::from(push_agg_call(AggCall::new(
PbAggKind::Sum.into(),
agg_call.args.clone(),
agg_call.distinct,
agg_call.order_by.clone(),
agg_call.filter.clone(),
agg_call.direct_args.clone(),
)?)?)
.cast_explicit(agg_call.return_type())?;
let count = ExprImpl::from(push_agg_call(AggCall::new(
PbAggKind::Count.into(),
agg_call.args.clone(),
agg_call.distinct,
agg_call.order_by.clone(),
agg_call.filter.clone(),
agg_call.direct_args.clone(),
)?)?);
Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
}
AggType::Builtin(
kind @ (PbAggKind::StddevPop
| PbAggKind::StddevSamp
| PbAggKind::VarPop
| PbAggKind::VarSamp),
) => {
let arg = agg_call.args().iter().exactly_one().unwrap();
let squared_arg = ExprImpl::from(FunctionCall::new(
ExprType::Multiply,
vec![arg.clone(), arg.clone()],
)?);
let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
PbAggKind::Sum.into(),
vec![squared_arg],
agg_call.distinct,
agg_call.order_by.clone(),
agg_call.filter.clone(),
agg_call.direct_args.clone(),
)?)?)
.cast_explicit(agg_call.return_type())?;
let sum = ExprImpl::from(push_agg_call(AggCall::new(
PbAggKind::Sum.into(),
agg_call.args.clone(),
agg_call.distinct,
agg_call.order_by.clone(),
agg_call.filter.clone(),
agg_call.direct_args.clone(),
)?)?)
.cast_explicit(agg_call.return_type())?;
let count = ExprImpl::from(push_agg_call(AggCall::new(
PbAggKind::Count.into(),
agg_call.args.clone(),
agg_call.distinct,
agg_call.order_by.clone(),
agg_call.filter.clone(),
agg_call.direct_args.clone(),
)?)?);
let zero = ExprImpl::literal_int(0);
let one = ExprImpl::literal_int(1);
let squared_sum = ExprImpl::from(FunctionCall::new(
ExprType::Multiply,
vec![sum.clone(), sum],
)?);
let raw_numerator = ExprImpl::from(FunctionCall::new(
ExprType::Subtract,
vec![
sum_of_sq,
ExprImpl::from(FunctionCall::new(
ExprType::Divide,
vec![squared_sum, count.clone()],
)?),
],
)?);
let numerator_type = raw_numerator.return_type();
let numerator = ExprImpl::from(FunctionCall::new(
ExprType::Greatest,
vec![raw_numerator, zero.clone().cast_explicit(numerator_type)?],
)?);
let denominator = match kind {
PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
),
_ => unreachable!(),
};
let mut target = ExprImpl::from(FunctionCall::new(
ExprType::Divide,
vec![numerator, denominator],
)?);
if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
}
let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
let case_cond = match kind {
PbAggKind::VarPop | PbAggKind::StddevPop => {
ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
}
PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
),
_ => unreachable!(),
};
Ok(ExprImpl::from(FunctionCall::new(
ExprType::Case,
vec![case_cond, null, target],
)?))
}
AggType::Builtin(PbAggKind::ApproxPercentile) => {
if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
let prev_percentile = agg_call.direct_args[0].clone();
let new_percentile = 1.0
- prev_percentile
.get_data()
.as_ref()
.unwrap()
.as_float64()
.into_inner();
let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
let new_percentile = Literal::new(new_percentile, DataType::Float64);
let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];
let new_agg_call = AggCall {
order_by: OrderBy::any(),
direct_args: new_direct_args,
..agg_call
};
Ok(push_agg_call(new_agg_call)?.into())
} else {
let new_agg_call = AggCall {
order_by: OrderBy::any(),
..agg_call
};
Ok(push_agg_call(new_agg_call)?.into())
}
}
_ => Ok(push_agg_call(agg_call)?.into()),
}
}
fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
let AggCall {
agg_type,
return_type,
args,
distinct,
order_by,
filter,
direct_args,
} = agg_call;
self.is_in_filter_clause = true;
let filter = filter.rewrite_expr(self);
self.is_in_filter_clause = false;
let args: Vec<_> = args
.iter()
.map(|expr| {
let index = self.input_proj_builder.add_expr(expr)?;
Ok(InputRef::new(index, expr.return_type()))
})
.try_collect()
.map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;
let order_by: Vec<_> = order_by
.sort_exprs
.iter()
.map(|e| {
let index = self.input_proj_builder.add_expr(&e.expr)?;
Ok(ColumnOrder::new(index, e.order_type))
})
.try_collect()
.map_err(|err: &'static str| {
not_implemented!("{err} inside aggregation calls order by")
})?;
let plan_agg_call = PlanAggCall {
agg_type,
return_type: return_type.clone(),
inputs: args,
distinct,
order_by,
filter,
direct_args,
};
if let Some((pos, existing)) = self
.agg_calls
.iter()
.find_position(|&c| c == &plan_agg_call)
{
return Ok(InputRef::new(
self.schema_agg_start_offset() + pos,
existing.return_type.clone(),
));
}
let index = self.schema_agg_start_offset() + self.agg_calls.len();
self.agg_calls.push(plan_agg_call);
Ok(InputRef::new(index, return_type))
}
fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
&& agg_call.order_by.sort_exprs.is_empty()
{
return Err(ErrorCode::InvalidInputSyntax(format!(
"Aggregation function {} requires ORDER BY clause",
agg_call.agg_type
))
.into());
}
if matches!(
agg_call.agg_type,
agg_types::result_unaffected_by_order_by!()
) {
agg_call.order_by = OrderBy::any();
}
if matches!(
agg_call.agg_type,
agg_types::result_unaffected_by_distinct!()
) {
agg_call.distinct = false;
}
if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
if self.grouping_sets.is_empty() {
return Err(ErrorCode::NotSupported(
"GROUPING must be used in a query with grouping sets".into(),
"try to use grouping sets instead".into(),
)
.into());
}
if agg_call.args.len() >= 32 {
return Err(ErrorCode::InvalidInputSyntax(
"GROUPING must have fewer than 32 arguments".into(),
)
.into());
}
if agg_call
.args
.iter()
.any(|x| self.try_as_group_expr(x).is_none())
{
return Err(ErrorCode::InvalidInputSyntax(
"arguments to GROUPING must be grouping expressions of the associated query level"
.into(),
).into());
}
}
Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
}
}
impl ExprRewriter for LogicalAggBuilder {
fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
let dummy = Literal::new(None, agg_call.return_type()).into();
match self.try_rewrite_agg_call(agg_call) {
Ok(expr) => expr,
Err(err) => {
self.error = Some(err);
dummy
}
}
}
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
let expr = func_call.into();
if let Some(group_key) = self.try_as_group_expr(&expr) {
InputRef::new(group_key, expr.return_type()).into()
} else {
let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
let inputs = inputs
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
FunctionCall::new_unchecked(func_type, inputs, ret).into()
}
}
fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
let WindowFunction {
args,
partition_by,
order_by,
..
} = window_func;
let args = args
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
let partition_by = partition_by
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
let order_by = order_by.rewrite_expr(self);
WindowFunction {
args,
partition_by,
order_by,
..window_func
}
.into()
}
fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
let expr = input_ref.into();
if let Some(group_key) = self.try_as_group_expr(&expr) {
InputRef::new(group_key, expr.return_type()).into()
} else if self.is_in_filter_clause {
InputRef::new(
self.input_proj_builder.add_expr(&expr).unwrap(),
expr.return_type(),
)
.into()
} else {
self.error = Some(
ErrorCode::InvalidInputSyntax(
"column must appear in the GROUP BY clause or be used in an aggregate function"
.into(),
)
.into(),
);
expr
}
}
fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
if subquery.is_correlated(0) {
self.error = Some(
not_implemented!(
issue = 2275,
"correlated subquery in HAVING or SELECT with agg",
)
.into(),
);
}
subquery.into()
}
}
impl From<Agg<PlanRef>> for LogicalAgg {
fn from(core: Agg<PlanRef>) -> Self {
let base = PlanBase::new_logical_with_core(&core);
Self { base, core }
}
}
impl From<Agg<PlanRef>> for PlanRef {
fn from(core: Agg<PlanRef>) -> Self {
LogicalAgg::from(core).into()
}
}
impl LogicalAgg {
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
self.core.i2o_col_mapping()
}
pub fn create(
select_exprs: Vec<ExprImpl>,
group_by: GroupBy,
having: Option<ExprImpl>,
input: PlanRef,
) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
let rewritten_select_exprs = select_exprs
.into_iter()
.map(|expr| agg_builder.rewrite_with_error(expr))
.collect::<Result<_>>()?;
let rewritten_having = having
.map(|expr| agg_builder.rewrite_with_error(expr))
.transpose()?;
Ok((
agg_builder.build(input).into(),
rewritten_select_exprs,
rewritten_having,
))
}
pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
&self.core.agg_calls
}
pub fn group_key(&self) -> &IndexSet {
&self.core.group_key
}
pub fn grouping_sets(&self) -> &Vec<IndexSet> {
&self.core.grouping_sets
}
pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
self.core.decompose()
}
#[must_use]
pub fn rewrite_with_input_agg(
&self,
input: PlanRef,
agg_calls: &[PlanAggCall],
mut input_col_change: ColIndexMapping,
) -> (Self, ColIndexMapping) {
let agg_calls = agg_calls
.iter()
.cloned()
.map(|mut agg_call| {
agg_call.inputs.iter_mut().for_each(|i| {
*i = InputRef::new(input_col_change.map(i.index()), i.return_type())
});
agg_call.order_by.iter_mut().for_each(|o| {
o.column_index = input_col_change.map(o.column_index);
});
agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
agg_call
})
.collect();
let group_key_in_vec: Vec<usize> = self
.group_key()
.indices()
.map(|key| input_col_change.map(key))
.collect();
let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
let grouping_sets = self
.grouping_sets()
.iter()
.map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
.collect();
let new_agg = Agg::new(agg_calls, group_key.clone(), input)
.with_grouping_sets(grouping_sets)
.with_enable_two_phase(self.core().enable_two_phase);
let mut out_col_change = vec![];
for idx in group_key_in_vec {
let pos = group_key.indices().position(|x| x == idx).unwrap();
out_col_change.push(pos);
}
for i in (group_key.len())..new_agg.schema().len() {
out_col_change.push(i);
}
let out_col_change =
ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
(new_agg.into(), out_col_change)
}
}
impl PlanTreeNodeUnary for LogicalAgg {
fn input(&self) -> PlanRef {
self.core.input.clone()
}
fn clone_with_input(&self, input: PlanRef) -> Self {
Agg::new(self.agg_calls().to_vec(), self.group_key().clone(), input)
.with_grouping_sets(self.grouping_sets().clone())
.with_enable_two_phase(self.core().enable_two_phase)
.into()
}
#[must_use]
fn rewrite_with_input(
&self,
input: PlanRef,
input_col_change: ColIndexMapping,
) -> (Self, ColIndexMapping) {
self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
}
}
impl_plan_tree_node_for_unary! {LogicalAgg}
impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
impl ExprRewritable for LogicalAgg {
fn has_rewritable_expr(&self) -> bool {
true
}
fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
let mut core = self.core.clone();
core.rewrite_exprs(r);
Self {
base: self.base.clone_with_new_plan_id(),
core,
}
.into()
}
}
impl ExprVisitable for LogicalAgg {
fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
self.core.visit_exprs(v);
}
}
impl ColPrunable for LogicalAgg {
fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
let group_key_required_cols = self.group_key().to_bitset();
let (agg_call_required_cols, agg_calls) = {
let input_cnt = self.input().schema().len();
let mut tmp = FixedBitSet::with_capacity(input_cnt);
let group_key_cardinality = self.group_key().len();
let new_agg_calls = required_cols
.iter()
.filter(|&&index| index >= group_key_cardinality)
.map(|&index| {
let index = index - group_key_cardinality;
let agg_call = self.agg_calls()[index].clone();
tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
for i in &agg_call.filter.conjunctions {
tmp.union_with(&i.collect_input_refs(input_cnt));
}
agg_call
})
.collect_vec();
(tmp, new_agg_calls)
};
let input_required_cols = {
let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
tmp.union_with(&group_key_required_cols);
tmp.union_with(&agg_call_required_cols);
tmp.ones().collect_vec()
};
let input_col_change = ColIndexMapping::with_remaining_columns(
&input_required_cols,
self.input().schema().len(),
);
let agg = {
let input = self.input().prune_col(&input_required_cols, ctx);
let (agg, output_col_change) =
self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
assert!(output_col_change.is_identity());
agg
};
let new_output_cols = {
let group_key_cardinality = agg.group_key().len();
let mut tmp = (0..group_key_cardinality).collect_vec();
tmp.extend(
required_cols
.iter()
.filter(|&&index| index >= group_key_cardinality),
);
tmp
};
if new_output_cols == required_cols {
agg.into()
} else {
let mapping =
&ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
let output_required_cols = required_cols
.iter()
.map(|&idx| mapping.map(idx))
.collect_vec();
let src_size = agg.schema().len();
LogicalProject::with_mapping(
agg.into(),
ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
)
.into()
}
}
}
impl PredicatePushdown for LogicalAgg {
fn predicate_pushdown(
&self,
predicate: Condition,
ctx: &mut PredicatePushdownContext,
) -> PlanRef {
let num_group_key = self.group_key().len();
let num_agg_calls = self.agg_calls().len();
assert!(num_group_key + num_agg_calls == self.schema().len());
if num_group_key == 0 {
return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
}
let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
let mut subst = Substitute {
mapping: self
.group_key()
.indices()
.enumerate()
.map(|(i, group_key)| {
InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
})
.collect(),
};
let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
}
}
impl ToBatch for LogicalAgg {
fn to_batch(&self) -> Result<PlanRef> {
self.to_batch_with_order_required(&Order::any())
}
fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
let input = self.input().to_batch()?;
let new_logical = Agg {
input,
..self.core.clone()
};
let agg_plan = if self.group_key().is_empty() {
BatchSimpleAgg::new(new_logical).into()
} else if self.ctx().session_ctx().config().batch_enable_sort_agg()
&& new_logical.input_provides_order_on_group_keys()
{
BatchSortAgg::new(new_logical).into()
} else {
BatchHashAgg::new(new_logical).into()
};
required_order.enforce_if_not_satisfies(agg_plan)
}
}
fn find_or_append_row_count(mut logical: Agg<PlanRef>) -> (Agg<PlanRef>, usize) {
let count_star = PlanAggCall::count_star();
let row_count_idx = if let Some((idx, _)) = logical
.agg_calls
.iter()
.find_position(|&c| c == &count_star)
{
idx
} else {
let idx = logical.agg_calls.len();
logical.agg_calls.push(count_star);
idx
};
(logical, row_count_idx)
}
fn new_stream_simple_agg(core: Agg<PlanRef>, must_output_per_barrier: bool) -> StreamSimpleAgg {
let (logical, row_count_idx) = find_or_append_row_count(core);
StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
}
fn new_stream_hash_agg(core: Agg<PlanRef>, vnode_col_idx: Option<usize>) -> StreamHashAgg {
let (logical, row_count_idx) = find_or_append_row_count(core);
StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
}
impl ToStream for LogicalAgg {
fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
use super::stream::prelude::*;
for agg_call in self.agg_calls() {
if matches!(agg_call.agg_type, agg_types::unimplemented_in_stream!()) {
bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_type);
}
}
let eowc = ctx.emit_on_window_close();
let stream_input = self.input().to_stream(ctx)?;
if stream_input.append_only() && self.agg_calls().is_empty() {
let input = if self.group_key().len() != self.input().schema().len() {
let cols = &self.group_key().to_vec();
LogicalProject::with_mapping(
self.input(),
ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
)
.into()
} else {
self.input()
};
let input_schema_len = input.schema().len();
let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
return logical_dedup.to_stream(ctx);
}
let plan = self.gen_dist_stream_agg_plan(stream_input)?;
let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
if eowc {
return Err(ErrorCode::InvalidInputSyntax(
"`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
.to_string(),
)
.into());
}
(plan.clone(), final_agg.agg_calls().len())
} else if let Some(final_agg) = plan.as_stream_hash_agg() {
(
if eowc {
final_agg.to_eowc_version()?
} else {
plan.clone()
},
final_agg.agg_calls().len(),
)
} else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
if eowc {
return Err(ErrorCode::InvalidInputSyntax(
"`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
.to_string(),
)
.into());
}
(plan.clone(), 1)
} else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
if eowc {
return Err(ErrorCode::InvalidInputSyntax(
"`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
.to_string(),
)
.into());
}
(plan.clone(), stream_row_merge.base.schema().len())
} else {
panic!("the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge");
};
if self.agg_calls().len() == n_final_agg_calls {
Ok(plan)
} else {
assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
Ok(StreamProject::new(generic::Project::with_out_col_idx(
plan,
0..self.schema().len(),
))
.with_noop_update_hint(self.agg_calls().is_empty())
.into())
}
}
fn logical_rewrite_for_stream(
&self,
ctx: &mut RewriteStreamContext,
) -> Result<(PlanRef, ColIndexMapping)> {
let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
let (map, _) = out_col_change.into_parts();
let out_col_change = ColIndexMapping::new(map, agg.schema().len());
Ok((agg.into(), out_col_change))
}
}
#[cfg(test)]
mod tests {
use risingwave_common::catalog::{Field, Schema};
use super::*;
use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
use crate::optimizer::optimizer_context::OptimizerContext;
use crate::optimizer::plan_node::LogicalValues;
#[tokio::test]
async fn test_create() {
let ty = DataType::Int32;
let ctx = OptimizerContext::mock().await;
let fields: Vec<Field> = vec![
Field::with_name(ty.clone(), "v1"),
Field::with_name(ty.clone(), "v2"),
Field::with_name(ty.clone(), "v3"),
];
let values = LogicalValues::new(vec![], Schema { fields }, ctx);
let input = PlanRef::from(values);
let input_ref_1 = InputRef::new(0, ty.clone());
let input_ref_2 = InputRef::new(1, ty.clone());
let input_ref_3 = InputRef::new(2, ty.clone());
let gen_internal_value = |select_exprs: Vec<ExprImpl>,
group_exprs|
-> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
let (plan, exprs, _) = LogicalAgg::create(
select_exprs,
GroupBy::GroupKey(group_exprs),
None,
input.clone(),
)
.unwrap();
let logical_agg = plan.as_logical_agg().unwrap();
let agg_calls = logical_agg.agg_calls().to_vec();
let group_key = logical_agg.group_key().clone();
(exprs, agg_calls, group_key)
};
{
let select_exprs = vec![input_ref_1.clone().into()];
let group_exprs = vec![input_ref_1.clone().into()];
let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
assert_eq!(exprs.len(), 1);
assert_eq_input_ref!(&exprs[0], 0);
assert_eq!(agg_calls.len(), 0);
assert_eq!(group_key, vec![0].into());
}
{
let min_v2 = AggCall::new(
PbAggKind::Min.into(),
vec![input_ref_2.clone().into()],
false,
OrderBy::any(),
Condition::true_cond(),
vec![],
)
.unwrap();
let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
let group_exprs = vec![input_ref_1.clone().into()];
let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
assert_eq!(exprs.len(), 2);
assert_eq_input_ref!(&exprs[0], 0);
assert_eq_input_ref!(&exprs[1], 1);
assert_eq!(agg_calls.len(), 1);
assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
assert_eq!(group_key, vec![0].into());
}
{
let min_v2 = AggCall::new(
PbAggKind::Min.into(),
vec![input_ref_2.clone().into()],
false,
OrderBy::any(),
Condition::true_cond(),
vec![],
)
.unwrap();
let max_v3 = AggCall::new(
PbAggKind::Max.into(),
vec![input_ref_3.clone().into()],
false,
OrderBy::any(),
Condition::true_cond(),
vec![],
)
.unwrap();
let func_call =
FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
let group_exprs = vec![input_ref_1.clone().into()];
let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
assert_eq_input_ref!(&exprs[0], 0);
if let ExprImpl::FunctionCall(func_call) = &exprs[1] {
assert_eq!(func_call.func_type(), ExprType::Add);
let inputs = func_call.inputs();
assert_eq_input_ref!(&inputs[0], 1);
assert_eq_input_ref!(&inputs[1], 2);
} else {
panic!("Wrong expression type!");
}
assert_eq!(agg_calls.len(), 2);
assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
assert_eq!(group_key, vec![0].into());
}
{
let v1_mult_v3 = FunctionCall::new(
ExprType::Multiply,
vec![input_ref_1.into(), input_ref_3.into()],
)
.unwrap();
let agg_call = AggCall::new(
PbAggKind::Min.into(),
vec![v1_mult_v3.into()],
false,
OrderBy::any(),
Condition::true_cond(),
vec![],
)
.unwrap();
let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
let group_exprs = vec![input_ref_2.into()];
let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
assert_eq_input_ref!(&exprs[0], 0);
assert_eq_input_ref!(&exprs[1], 1);
assert_eq!(agg_calls.len(), 1);
assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
assert_eq!(group_key, vec![0].into());
}
}
async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
let ctx = OptimizerContext::mock().await;
let values = LogicalValues::new(vec![], Schema { fields }, ctx);
let agg_call = PlanAggCall {
agg_type: PbAggKind::Min.into(),
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
order_by: vec![],
filter: Condition::true_cond(),
direct_args: vec![],
};
Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
}
#[tokio::test]
async fn test_prune_all() {
let ty = DataType::Int32;
let fields: Vec<Field> = vec![
Field::with_name(ty.clone(), "v1"),
Field::with_name(ty.clone(), "v2"),
Field::with_name(ty.clone(), "v3"),
];
let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
let required_cols = vec![0, 1];
let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
let agg_new = plan.as_logical_agg().unwrap();
assert_eq!(agg_new.group_key(), &vec![0].into());
assert_eq!(agg_new.agg_calls().len(), 1);
let agg_call_new = agg_new.agg_calls()[0].clone();
assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
assert_eq!(agg_call_new.return_type, ty);
let values = agg_new.input();
let values = values.as_logical_values().unwrap();
assert_eq!(values.schema().fields(), &fields[1..]);
}
#[tokio::test]
async fn test_prune_all_with_order_required() {
let ty = DataType::Int32;
let fields: Vec<Field> = vec![
Field::with_name(ty.clone(), "v1"),
Field::with_name(ty.clone(), "v2"),
Field::with_name(ty.clone(), "v3"),
];
let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
let required_cols = vec![1, 0];
let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
let proj = plan.as_logical_project().unwrap();
assert_eq!(proj.exprs().len(), 2);
assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);
let proj_input = proj.input();
let agg_new = proj_input.as_logical_agg().unwrap();
assert_eq!(agg_new.group_key(), &vec![0].into());
assert_eq!(agg_new.agg_calls().len(), 1);
let agg_call_new = agg_new.agg_calls()[0].clone();
assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
assert_eq!(agg_call_new.return_type, ty);
let values = agg_new.input();
let values = values.as_logical_values().unwrap();
assert_eq!(values.schema().fields(), &fields[1..]);
}
#[tokio::test]
async fn test_prune_group_key() {
let ctx = OptimizerContext::mock().await;
let ty = DataType::Int32;
let fields: Vec<Field> = vec![
Field::with_name(ty.clone(), "v1"),
Field::with_name(ty.clone(), "v2"),
Field::with_name(ty.clone(), "v3"),
];
let values: LogicalValues = LogicalValues::new(
vec![],
Schema {
fields: fields.clone(),
},
ctx,
);
let agg_call = PlanAggCall {
agg_type: PbAggKind::Min.into(),
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
order_by: vec![],
filter: Condition::true_cond(),
direct_args: vec![],
};
let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();
let required_cols = vec![1];
let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
let project = plan.as_logical_project().unwrap();
assert_eq!(project.exprs().len(), 1);
assert_eq_input_ref!(&project.exprs()[0], 1);
let agg_new = project.input();
let agg_new = agg_new.as_logical_agg().unwrap();
assert_eq!(agg_new.group_key(), &vec![0].into());
assert_eq!(agg_new.agg_calls().len(), 1);
let agg_call_new = agg_new.agg_calls()[0].clone();
assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
assert_eq!(agg_call_new.return_type, ty);
let values = agg_new.input();
let values = values.as_logical_values().unwrap();
assert_eq!(values.schema().fields(), &fields[1..]);
}
#[tokio::test]
async fn test_prune_agg() {
let ty = DataType::Int32;
let ctx = OptimizerContext::mock().await;
let fields: Vec<Field> = vec![
Field::with_name(ty.clone(), "v1"),
Field::with_name(ty.clone(), "v2"),
Field::with_name(ty.clone(), "v3"),
];
let values = LogicalValues::new(
vec![],
Schema {
fields: fields.clone(),
},
ctx,
);
let agg_calls = vec![
PlanAggCall {
agg_type: PbAggKind::Min.into(),
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
order_by: vec![],
filter: Condition::true_cond(),
direct_args: vec![],
},
PlanAggCall {
agg_type: PbAggKind::Max.into(),
return_type: ty.clone(),
inputs: vec![InputRef::new(1, ty.clone())],
distinct: false,
order_by: vec![],
filter: Condition::true_cond(),
direct_args: vec![],
},
];
let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();
let required_cols = vec![0, 3];
let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
let project = plan.as_logical_project().unwrap();
assert_eq!(project.exprs().len(), 2);
assert_eq_input_ref!(&project.exprs()[0], 0);
assert_eq_input_ref!(&project.exprs()[1], 2);
let agg_new = project.input();
let agg_new = agg_new.as_logical_agg().unwrap();
assert_eq!(agg_new.group_key(), &vec![0, 1].into());
assert_eq!(agg_new.agg_calls().len(), 1);
let agg_call_new = agg_new.agg_calls()[0].clone();
assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
assert_eq!(agg_call_new.return_type, ty);
let values = agg_new.input();
let values = values.as_logical_values().unwrap();
assert_eq!(values.schema().fields(), &fields[1..]);
}
}