risingwave_frontend/optimizer/rule/agg_project_merge_rule.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use itertools::Itertools;
use super::super::plan_node::*;
use super::{BoxedRule, Rule};
use crate::utils::IndexSet;
/// Merge [`LogicalAgg`] <- [`LogicalProject`] to [`LogicalAgg`].
///
/// The rule only fires when the project below the aggregate consists solely of input
/// references. The aggregate's group key and aggregate calls are then re-indexed through
/// the project's output-to-input column mapping so that the aggregate reads directly from
/// the project's input.
pub struct AggProjectMergeRule {}
impl Rule for AggProjectMergeRule {
    /// Fuses a logical aggregate with the logical project directly beneath it.
    ///
    /// Returns `None` when `plan` is not a logical agg, when the agg's input is not a
    /// logical project, or when that project contains anything other than input
    /// references. Otherwise returns an aggregate that reads straight from the project's
    /// input. If the rewritten group-key set does not enumerate its columns in the same
    /// sequence as the mapped group-key list, the aggregate is wrapped in a new project
    /// that restores the original output column layout.
    ///
    /// # Panics
    ///
    /// Panics if the aggregate has non-empty `grouping_sets`.
    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
        let core = plan.as_logical_agg()?.core();
        assert!(core.grouping_sets.is_empty());
        let proj = core.input.as_logical_project()?;
        // Only a project made purely of input references can be folded into the agg.
        if !proj.is_all_inputref() {
            return None;
        }
        let proj_o2i = proj.o2i_col_mapping();

        // Translate every group-key column through the project, keeping the sequence in
        // which the old group key enumerated them.
        let mapped_group_keys = core
            .group_key
            .indices()
            .map(|x| proj_o2i.map(x))
            .collect_vec();
        let new_group_key = IndexSet::from_iter(mapped_group_keys.iter().copied());

        // Copy the agg core only now that the rule is known to apply, so the rejected
        // paths above do not pay for the clone.
        let mut new_core = core.clone();
        new_core.input = proj.input();
        // Re-point every aggregate call at the project's input columns.
        for call in &mut new_core.agg_calls {
            call.rewrite_input_index(proj_o2i.clone());
        }
        new_core.group_key = new_group_key.clone();

        if new_group_key.to_vec() == mapped_group_keys {
            // The set enumerates exactly the mapped list: output layout is unchanged.
            return Some(new_core.into());
        }

        // The set's enumeration differs from the mapped list, so a project is needed on
        // top of the new agg to reproduce the original output columns: first the group
        // keys in their old sequence, then the aggregate-call columns that follow them.
        let group_key_len = new_group_key.len();
        let agg_call_cols = group_key_len..group_key_len + new_core.agg_calls.len();
        let out_col_idx = mapped_group_keys
            .into_iter()
            .map(|x| {
                new_group_key
                    .indices()
                    .position(|y| y == x)
                    .expect("every mapped group key was inserted into the new group-key set")
            })
            .chain(agg_call_cols);
        Some(LogicalProject::with_out_col_idx(new_core.into(), out_col_idx).into())
    }
}
impl AggProjectMergeRule {
pub fn create() -> BoxedRule {
Box::new(AggProjectMergeRule {})
}
}