risingwave_frontend/optimizer/rule/
common_sub_expr_extract_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16
17use super::super::plan_node::*;
18use super::{BoxedRule, Rule};
19use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor, InputRef};
20use crate::optimizer::plan_expr_rewriter::CseRewriter;
21use crate::optimizer::plan_expr_visitor::CseExprCounter;
22use crate::optimizer::plan_node::generic::GenericPlanRef;
23
/// Optimizer rule that extracts repeated sub-expressions out of a `LogicalProject`.
///
/// When applied, the matched project is replaced by two stacked projects:
/// - a bottom project that forwards every input column and appends one column per extracted
///   function-call expression;
/// - a top project whose expressions are rewritten to reference those appended columns.
pub struct CommonSubExprExtractRule {}
25impl Rule for CommonSubExprExtractRule {
26    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
27        let project: &LogicalProject = plan.as_logical_project()?;
28
29        let mut expr_counter = CseExprCounter::default();
30        for expr in project.exprs() {
31            expr_counter.visit_expr(expr);
32        }
33
34        if expr_counter.counter.values().all(|counter| *counter <= 1) {
35            return None;
36        }
37
38        let (exprs, input) = project.clone().decompose();
39        let input_schema_len = input.schema().len();
40        let mut cse_rewriter = CseRewriter::new(expr_counter, input_schema_len);
41        let top_project_exprs = exprs
42            .into_iter()
43            .map(|expr| cse_rewriter.rewrite_expr(expr))
44            .collect_vec();
45        let bottom_project_exprs = {
46            let mut exprs = Vec::with_capacity(input_schema_len + cse_rewriter.cse_mapping.len());
47            for (i, field) in input.schema().fields.iter().enumerate() {
48                let expr = ExprImpl::InputRef(InputRef::new(i, field.data_type.clone()).into());
49                exprs.push(expr);
50            }
51            exprs.extend(
52                cse_rewriter
53                    .cse_mapping
54                    .into_iter()
55                    .sorted_by(|(_, v1), (_, v2)| Ord::cmp(&v1.index, &v2.index))
56                    .map(|(k, _)| ExprImpl::FunctionCall(k.into())),
57            );
58            exprs
59        };
60        let bottom_project = LogicalProject::new(input, bottom_project_exprs);
61        let top_project = LogicalProject::new(bottom_project.into(), top_project_exprs);
62        Some(top_project.into())
63    }
64}
65
66impl CommonSubExprExtractRule {
67    pub fn create() -> BoxedRule {
68        Box::new(CommonSubExprExtractRule {})
69    }
70}