risingwave_frontend/optimizer/plan_node/
col_pruning.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use super::*;
18use crate::optimizer::plan_visitor::ShareParentCounter;
19use crate::optimizer::{LogicalPlanRef as PlanRef, PlanVisitor};
20
21/// The trait for column pruning, only logical plan node will use it, though all plan node impl it.
22pub trait ColPrunable {
23    /// Transform the plan node to only output the required columns ordered by index number.
24    ///
25    /// `required_cols` must be a subset of the range `0..self.schema().len()`.
26    ///
27    /// After calling `prune_col` on the children, their output schema may change, so
28    /// the caller may need to transform its [`InputRef`](crate::expr::InputRef) using
29    /// [`ColIndexMapping`](crate::utils::ColIndexMapping).
30    ///
31    /// When implementing this method for a node, it may require its children to produce additional
32    /// columns besides `required_cols`. In this case, it may need to insert a
33    /// [`LogicalProject`](super::LogicalProject) above to have a correct schema.
34    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef;
35}
36
37#[derive(Debug, Clone)]
38pub struct ColumnPruningContext {
39    /// `share_required_cols_map` is used by the first round of column pruning to keep track of
40    /// each parent required columns.
41    share_required_cols_map: HashMap<PlanNodeId, Vec<Vec<usize>>>,
42    /// Used to calculate how many parents the share operator has.
43    share_parent_counter: ShareParentCounter,
44    /// Share input cache used by the second round of column pruning.
45    /// For a DAG plan, use only one round to prune column is not enough,
46    /// because we need to change the schema of share operator
47    /// and you don't know what is the final schema when the first parent try to prune column,
48    /// so we need a second round to use the information collected by the first round.
49    /// `share_cache` maps original share operator plan id to the new share operator and the column
50    /// changed mapping which is actually the merged required columns calculated at the first
51    /// round.
52    share_cache: HashMap<PlanNodeId, (PlanRef, Vec<usize>)>,
53    /// `share_visited` is used to track whether the share operator is visited, because we need to
54    /// recursively call the `prune_col` of the new share operator to trigger the replacement.
55    /// It is only used at the second round of the column pruning.
56    share_visited: HashSet<PlanNodeId>,
57}
58
59impl ColumnPruningContext {
60    pub fn new(root: PlanRef) -> Self {
61        let mut share_parent_counter = ShareParentCounter::default();
62        share_parent_counter.visit(root);
63        Self {
64            share_required_cols_map: Default::default(),
65            share_parent_counter,
66            share_cache: Default::default(),
67            share_visited: Default::default(),
68        }
69    }
70
71    pub fn get_parent_num(&self, share: &LogicalShare) -> usize {
72        self.share_parent_counter.get_parent_num(share)
73    }
74
75    pub fn add_required_cols(
76        &mut self,
77        plan_node_id: PlanNodeId,
78        required_cols: Vec<usize>,
79    ) -> usize {
80        self.share_required_cols_map
81            .entry(plan_node_id)
82            .and_modify(|e| e.push(required_cols.clone()))
83            .or_insert_with(|| vec![required_cols])
84            .len()
85    }
86
87    pub fn take_required_cols(&mut self, plan_node_id: PlanNodeId) -> Option<Vec<Vec<usize>>> {
88        self.share_required_cols_map.remove(&plan_node_id)
89    }
90
91    pub fn add_share_cache(
92        &mut self,
93        plan_node_id: PlanNodeId,
94        new_share: PlanRef,
95        merged_required_columns: Vec<usize>,
96    ) {
97        self.share_cache
98            .try_insert(plan_node_id, (new_share, merged_required_columns))
99            .unwrap();
100    }
101
102    pub fn get_share_cache(&self, plan_node_id: PlanNodeId) -> Option<(PlanRef, Vec<usize>)> {
103        self.share_cache.get(&plan_node_id).cloned()
104    }
105
106    pub fn need_second_round(&self) -> bool {
107        !self.share_cache.is_empty()
108    }
109
110    pub fn visit_share_at_first_round(&mut self, plan_node_id: PlanNodeId) -> bool {
111        self.share_visited.insert(plan_node_id)
112    }
113}