risingwave_frontend/optimizer/plan_rewriter/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod plan_cloner;
16mod share_source_rewriter;
17
18use std::collections::HashMap;
19
20use itertools::Itertools;
21pub use plan_cloner::*;
22pub use share_source_rewriter::*;
23
24use crate::optimizer::plan_node::generic::GenericPlanRef;
25use crate::optimizer::plan_node::*;
26
27pub trait PlanRewriter<C: ConventionMarker> {
28    fn rewrite_with_inputs(&mut self, plan: &PlanRef<C>, inputs: Vec<PlanRef<C>>) -> PlanRef<C>;
29}
30
31impl<C: ConventionMarker> PlanRef<C> {
32    pub fn rewrite_with(&self, rewriter: &mut impl PlanRewriter<C>) -> PlanRef<C> {
33        let mut share_map = HashMap::new();
34        self.rewrite_recursively(rewriter, &mut share_map)
35    }
36
37    fn rewrite_recursively(
38        &self,
39        rewriter: &mut impl PlanRewriter<C>,
40        share_map: &mut HashMap<PlanNodeId, PlanRef<C>>,
41    ) -> PlanRef<C> {
42        use risingwave_common::util::recursive::{Recurse, tracker};
43
44        use crate::session::current::notice_to_user;
45        tracker!().recurse(|t| {
46            if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
47                notice_to_user(PLAN_TOO_DEEP_NOTICE);
48            }
49
50            if let Some(share) = self.as_share_node() {
51                let id = share.plan_base().id();
52                return if let Some(share) = share_map.get(&id) {
53                    share.clone()
54                } else {
55                    let input = share.input();
56                    let new_input = input.rewrite_recursively(rewriter, share_map);
57                    let new_plan = C::ShareNode::new_share(generic::Share::new(new_input));
58                    share_map
59                        .try_insert(id, new_plan.clone())
60                        .expect("non-duplicate");
61                    new_plan
62                };
63            }
64
65            let inputs = self
66                .inputs()
67                .iter()
68                .map(|plan| plan.rewrite_recursively(rewriter, share_map))
69                .collect_vec();
70            rewriter.rewrite_with_inputs(self, inputs)
71        })
72    }
73}