// risingwave_frontend/optimizer/plan_node/col_pruning.rs
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::{HashMap, HashSet};
use paste::paste;
use super::*;
use crate::optimizer::plan_visitor::ShareParentCounter;
use crate::optimizer::PlanVisitor;
use crate::{for_batch_plan_nodes, for_stream_plan_nodes};
/// The trait for column pruning. Only logical plan nodes implement it meaningfully;
/// batch and stream nodes get a panicking impl via the `impl_prune_col!` macro below.
pub trait ColPrunable {
/// Transform the plan node to only output the required columns ordered by index number.
///
/// `required_cols` must be a subset of the range `0..self.schema().len()`.
///
/// After calling `prune_col` on the children, their output schema may change, so
/// the caller may need to transform its [`InputRef`](crate::expr::InputRef) using
/// [`ColIndexMapping`](crate::utils::ColIndexMapping).
///
/// When implementing this method for a node, it may require its children to produce additional
/// columns besides `required_cols`. In this case, it may need to insert a
/// [`LogicalProject`](super::LogicalProject) above to have a correct schema.
fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef;
}
/// Implements [`ColPrunable`] for batch and streaming node.
///
/// Column pruning must happen before the logical plan is lowered, so every
/// batch/stream node gets a panicking stub implementation.
macro_rules! impl_prune_col {
($( { $convention:ident, $name:ident }),*) => {
// `paste!` concatenates `$convention` and `$name` into the concrete node
// type name, e.g. `{Batch, Filter}` -> `BatchFilter`.
paste!{
$(impl ColPrunable for [<$convention $name>] {
fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
// Reaching this is a planner bug: pruning is a logical-plan-only pass.
panic!("column pruning is only allowed on logical plan")
}
})*
}
}
}
// Stamp out the panicking impl for every batch and stream plan node type.
for_batch_plan_nodes! { impl_prune_col }
for_stream_plan_nodes! { impl_prune_col }
/// State threaded through the (up to two) rounds of column pruning over a plan DAG.
#[derive(Debug, Clone)]
pub struct ColumnPruningContext {
/// `share_required_cols_map` is used by the first round of column pruning to keep track of
/// each parent's required columns.
share_required_cols_map: HashMap<PlanNodeId, Vec<Vec<usize>>>,
/// Used to calculate how many parents the share operator has.
share_parent_counter: ShareParentCounter,
/// Share input cache used by the second round of column pruning.
/// For a DAG plan, one round of pruning is not enough,
/// because we need to change the schema of the share operator
/// and we don't know the final schema when the first parent tries to prune columns,
/// so we need a second round that uses the information collected by the first round.
/// `share_cache` maps the original share operator plan id to the new share operator and the
/// column-change mapping, which is actually the merged required columns calculated during the
/// first round.
share_cache: HashMap<PlanNodeId, (PlanRef, Vec<usize>)>,
/// `share_visited` is used to track whether the share operator has been visited, because we
/// need to recursively call `prune_col` on the new share operator to trigger the replacement.
/// It is only used in the second round of column pruning.
share_visited: HashSet<PlanNodeId>,
}
impl ColumnPruningContext {
    /// Create a pruning context rooted at `root`, pre-computing how many parents
    /// each share operator has so that [`Self::add_required_cols`] callers can tell
    /// when every parent has reported its required columns.
    pub fn new(root: PlanRef) -> Self {
        let mut share_parent_counter = ShareParentCounter::default();
        share_parent_counter.visit(root);
        Self {
            share_required_cols_map: Default::default(),
            share_parent_counter,
            share_cache: Default::default(),
            share_visited: Default::default(),
        }
    }

    /// Number of parents the given share operator has, as counted in [`Self::new`].
    pub fn get_parent_num(&self, share: &LogicalShare) -> usize {
        self.share_parent_counter.get_parent_num(share)
    }

    /// Record one parent's `required_cols` for the share operator `plan_node_id`
    /// and return how many parents have reported so far.
    pub fn add_required_cols(
        &mut self,
        plan_node_id: PlanNodeId,
        required_cols: Vec<usize>,
    ) -> usize {
        // `entry(..).or_default()` lets us move `required_cols` exactly once,
        // avoiding the clone that the `and_modify` + `or_insert_with` pattern
        // needed to satisfy the borrow checker.
        let all_required_cols = self
            .share_required_cols_map
            .entry(plan_node_id)
            .or_default();
        all_required_cols.push(required_cols);
        all_required_cols.len()
    }

    /// Take (remove and return) every parent's required columns collected for
    /// `plan_node_id` during the first round, or `None` if nothing was recorded.
    pub fn take_required_cols(&mut self, plan_node_id: PlanNodeId) -> Option<Vec<Vec<usize>>> {
        self.share_required_cols_map.remove(&plan_node_id)
    }

    /// Cache the rewritten share operator and its merged required columns,
    /// keyed by the original share operator's plan id.
    ///
    /// # Panics
    ///
    /// Panics if an entry for `plan_node_id` was already inserted — each share
    /// operator must be cached at most once per pruning pass.
    pub fn add_share_cache(
        &mut self,
        plan_node_id: PlanNodeId,
        new_share: PlanRef,
        merged_required_columns: Vec<usize>,
    ) {
        self.share_cache
            .try_insert(plan_node_id, (new_share, merged_required_columns))
            .unwrap();
    }

    /// Look up the cached rewritten share operator for `plan_node_id`.
    /// Cloning is cheap: `PlanRef` is reference-counted and the column vec is small.
    pub fn get_share_cache(&self, plan_node_id: PlanNodeId) -> Option<(PlanRef, Vec<usize>)> {
        self.share_cache.get(&plan_node_id).cloned()
    }

    /// Whether a second pruning round is needed, i.e. at least one share
    /// operator was rewritten during the first round.
    pub fn need_second_round(&self) -> bool {
        !self.share_cache.is_empty()
    }

    /// Mark the share operator as visited during the second round; returns
    /// `true` on the first visit and `false` if it was already visited.
    pub fn visit_share_at_second_round(&mut self, plan_node_id: PlanNodeId) -> bool {
        self.share_visited.insert(plan_node_id)
    }
}