risingwave_frontend/optimizer/plan_node/col_pruning.rs
// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{HashMap, HashSet};

use paste::paste;

use super::*;
use crate::optimizer::PlanVisitor;
use crate::optimizer::plan_visitor::ShareParentCounter;
use crate::{for_batch_plan_nodes, for_stream_plan_nodes};

/// The trait for column pruning. Only logical plan nodes use it, though all plan nodes
/// implement it.
pub trait ColPrunable {
    /// Transforms the plan node so that it outputs only the required columns, ordered by index
    /// number.
    ///
    /// `required_cols` must be a subset of the range `0..self.schema().len()`.
    ///
    /// After calling `prune_col` on the children, their output schemas may change, so the caller
    /// may need to rewrite its [`InputRef`](crate::expr::InputRef)s using a
    /// [`ColIndexMapping`](crate::utils::ColIndexMapping).
    ///
    /// A node's implementation may require its children to produce additional columns besides
    /// `required_cols`. In that case, it may need to insert a
    /// [`LogicalProject`](super::LogicalProject) above the pruned child to obtain a correct
    /// schema.
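    ///
    /// # Example
    ///
    /// A minimal sketch of how a caller might drive pruning, assuming `plan` is a logical
    /// `PlanRef` with at least three output columns (illustrative only, not compiled):
    ///
    /// ```ignore
    /// let mut ctx = ColumnPruningContext::new(plan.clone());
    /// // Keep only columns 0 and 2 of `plan`'s output schema.
    /// let pruned = plan.prune_col(&[0, 2], &mut ctx);
    /// assert_eq!(pruned.schema().len(), 2);
    /// ```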
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef;
}

/// Implements [`ColPrunable`] for batch and stream plan nodes.
macro_rules! impl_prune_col {
    ($( { $convention:ident, $name:ident }),*) => {
        paste!{
            $(impl ColPrunable for [<$convention $name>] {
                fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
                    panic!("column pruning is only allowed on logical plan")
                }
            })*
        }
    }
}
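// For illustration, an entry such as `{ Batch, Foo }` (a hypothetical node name) would expand to
// roughly the following impl; the panic records that batch and stream nodes never prune columns:
//
//     impl ColPrunable for BatchFoo {
//         fn prune_col(&self, _required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
//             panic!("column pruning is only allowed on logical plan")
//         }
//     }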
for_batch_plan_nodes! { impl_prune_col }
for_stream_plan_nodes! { impl_prune_col }

#[derive(Debug, Clone)]
pub struct ColumnPruningContext {
    /// `share_required_cols_map` is used by the first round of column pruning to keep track of
    /// each parent's required columns.
    share_required_cols_map: HashMap<PlanNodeId, Vec<Vec<usize>>>,
    /// Used to calculate how many parents the share operator has.
    share_parent_counter: ShareParentCounter,
    /// Share input cache used by the second round of column pruning.
    ///
    /// For a DAG plan, a single round of pruning is not enough: the schema of a share operator
    /// has to change, but its final schema is unknown when the first parent tries to prune
    /// columns. A second round is therefore needed to apply the information collected in the
    /// first round (see the sketch after this struct definition).
    ///
    /// `share_cache` maps the original share operator's plan id to the new share operator and the
    /// column change mapping, which is the merged required columns computed in the first round.
    share_cache: HashMap<PlanNodeId, (PlanRef, Vec<usize>)>,
    /// `share_visited` tracks whether a share operator has already been visited, because
    /// `prune_col` must be called recursively on the new share operator to trigger the
    /// replacement. It is only used in the second round of column pruning.
    share_visited: HashSet<PlanNodeId>,
}
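// A hedged sketch of how the two rounds interact for a `LogicalShare` with id `share_id`; the
// context methods below are the ones defined in this file, while `share`, `new_share`,
// `merged_required_cols`, and the surrounding control flow are assumptions for illustration:
//
//     // First round: each parent registers the columns it needs from the share operator.
//     let reported = ctx.add_required_cols(share_id, required_cols.to_vec());
//     if reported == ctx.get_parent_num(share) {
//         // All parents have reported: merge their requirements, rebuild the share operator on
//         // the pruned input, and remember the merged mapping for the second round.
//         ctx.add_share_cache(share_id, new_share, merged_required_cols);
//     }
//
//     // Second round (only if `ctx.need_second_round()`): each original share operator is
//     // replaced by its cached counterpart exactly once.
//     if ctx.visit_share_at_second_round(share_id) {
//         let (new_share, mapping) = ctx.get_share_cache(share_id).unwrap();
//         // ...rewrite column references according to `mapping`...
//     }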

impl ColumnPruningContext {
    /// Creates a pruning context for the plan rooted at `root`, pre-computing how many parents
    /// each share operator has.
    pub fn new(root: PlanRef) -> Self {
        let mut share_parent_counter = ShareParentCounter::default();
        share_parent_counter.visit(root);
        Self {
            share_required_cols_map: Default::default(),
            share_parent_counter,
            share_cache: Default::default(),
            share_visited: Default::default(),
        }
    }

    /// Returns how many parents the given share operator has in the original plan.
    pub fn get_parent_num(&self, share: &LogicalShare) -> usize {
        self.share_parent_counter.get_parent_num(share)
    }

    /// Records one parent's required columns for the share operator identified by
    /// `plan_node_id`, and returns how many parents have reported so far.
    pub fn add_required_cols(
        &mut self,
        plan_node_id: PlanNodeId,
        required_cols: Vec<usize>,
    ) -> usize {
        self.share_required_cols_map
            .entry(plan_node_id)
            .and_modify(|e| e.push(required_cols.clone()))
            .or_insert_with(|| vec![required_cols])
            .len()
    }

    /// Removes and returns all required-column sets collected for `plan_node_id` in the first
    /// round, if any.
    pub fn take_required_cols(&mut self, plan_node_id: PlanNodeId) -> Option<Vec<Vec<usize>>> {
        self.share_required_cols_map.remove(&plan_node_id)
    }

    /// Caches the rebuilt share operator and its merged required columns for the second round.
    /// Panics if an entry for `plan_node_id` already exists.
    pub fn add_share_cache(
        &mut self,
        plan_node_id: PlanNodeId,
        new_share: PlanRef,
        merged_required_columns: Vec<usize>,
    ) {
        self.share_cache
            .try_insert(plan_node_id, (new_share, merged_required_columns))
            .unwrap();
    }

    /// Returns the cached share operator and merged required columns for `plan_node_id`, if any.
    pub fn get_share_cache(&self, plan_node_id: PlanNodeId) -> Option<(PlanRef, Vec<usize>)> {
        self.share_cache.get(&plan_node_id).cloned()
    }

    /// Returns `true` if any share operator has been rewritten, so a second pruning round is
    /// needed to apply the cached replacements.
    pub fn need_second_round(&self) -> bool {
        !self.share_cache.is_empty()
    }

    /// Marks the share operator as visited in the second round. Returns `true` if this is the
    /// first visit.
    pub fn visit_share_at_second_round(&mut self, plan_node_id: PlanNodeId) -> bool {
        self.share_visited.insert(plan_node_id)
    }
}