risingwave_frontend/optimizer/plan_node/col_pruning.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use super::*;
18use crate::optimizer::plan_visitor::ShareParentCounter;
19use crate::optimizer::{LogicalPlanRef as PlanRef, PlanVisitor};
20
21/// The trait for column pruning, only logical plan node will use it, though all plan node impl it.
22pub trait ColPrunable {
23 /// Transform the plan node to only output the required columns ordered by index number.
24 ///
25 /// `required_cols` must be a subset of the range `0..self.schema().len()`.
26 ///
27 /// After calling `prune_col` on the children, their output schema may change, so
28 /// the caller may need to transform its [`InputRef`](crate::expr::InputRef) using
29 /// [`ColIndexMapping`](crate::utils::ColIndexMapping).
30 ///
31 /// When implementing this method for a node, it may require its children to produce additional
32 /// columns besides `required_cols`. In this case, it may need to insert a
33 /// [`LogicalProject`](super::LogicalProject) above to have a correct schema.
34 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef;
35}
36
37#[derive(Debug, Clone)]
38pub struct ColumnPruningContext {
39 /// `share_required_cols_map` is used by the first round of column pruning to keep track of
40 /// each parent required columns.
41 share_required_cols_map: HashMap<PlanNodeId, Vec<Vec<usize>>>,
42 /// Used to calculate how many parents the share operator has.
43 share_parent_counter: ShareParentCounter,
44 /// Share input cache used by the second round of column pruning.
45 /// For a DAG plan, use only one round to prune column is not enough,
46 /// because we need to change the schema of share operator
47 /// and you don't know what is the final schema when the first parent try to prune column,
48 /// so we need a second round to use the information collected by the first round.
49 /// `share_cache` maps original share operator plan id to the new share operator and the column
50 /// changed mapping which is actually the merged required columns calculated at the first
51 /// round.
52 share_cache: HashMap<PlanNodeId, (PlanRef, Vec<usize>)>,
53 /// `share_visited` is used to track whether the share operator is visited, because we need to
54 /// recursively call the `prune_col` of the new share operator to trigger the replacement.
55 /// It is only used at the second round of the column pruning.
56 share_visited: HashSet<PlanNodeId>,
57}
58
59impl ColumnPruningContext {
60 pub fn new(root: PlanRef) -> Self {
61 let mut share_parent_counter = ShareParentCounter::default();
62 share_parent_counter.visit(root);
63 Self {
64 share_required_cols_map: Default::default(),
65 share_parent_counter,
66 share_cache: Default::default(),
67 share_visited: Default::default(),
68 }
69 }
70
71 pub fn get_parent_num(&self, share: &LogicalShare) -> usize {
72 self.share_parent_counter.get_parent_num(share)
73 }
74
75 pub fn add_required_cols(
76 &mut self,
77 plan_node_id: PlanNodeId,
78 required_cols: Vec<usize>,
79 ) -> usize {
80 self.share_required_cols_map
81 .entry(plan_node_id)
82 .and_modify(|e| e.push(required_cols.clone()))
83 .or_insert_with(|| vec![required_cols])
84 .len()
85 }
86
87 pub fn take_required_cols(&mut self, plan_node_id: PlanNodeId) -> Option<Vec<Vec<usize>>> {
88 self.share_required_cols_map.remove(&plan_node_id)
89 }
90
91 pub fn add_share_cache(
92 &mut self,
93 plan_node_id: PlanNodeId,
94 new_share: PlanRef,
95 merged_required_columns: Vec<usize>,
96 ) {
97 self.share_cache
98 .try_insert(plan_node_id, (new_share, merged_required_columns))
99 .unwrap();
100 }
101
102 pub fn get_share_cache(&self, plan_node_id: PlanNodeId) -> Option<(PlanRef, Vec<usize>)> {
103 self.share_cache.get(&plan_node_id).cloned()
104 }
105
106 pub fn need_second_round(&self) -> bool {
107 !self.share_cache.is_empty()
108 }
109
110 pub fn visit_share_at_first_round(&mut self, plan_node_id: PlanNodeId) -> bool {
111 self.share_visited.insert(plan_node_id)
112 }
113}