risingwave_frontend/optimizer/rule/
common_sub_expr_extract_rule.rs1use itertools::Itertools;
16
17use super::super::plan_node::*;
18use super::{BoxedRule, Rule};
19use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor, InputRef};
20use crate::optimizer::plan_expr_rewriter::CseRewriter;
21use crate::optimizer::plan_expr_visitor::CseExprCounter;
22use crate::optimizer::plan_node::generic::GenericPlanRef;
23
/// Optimizer rule that pulls sub-expressions occurring more than once (as counted by
/// `CseExprCounter`) across a `LogicalProject`'s expressions into a new `LogicalProject`
/// placed directly beneath it.
///
/// The rule carries no configuration; it is stateless.
pub struct CommonSubExprExtractRule {}
25impl Rule for CommonSubExprExtractRule {
26 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
27 let project: &LogicalProject = plan.as_logical_project()?;
28
29 let mut expr_counter = CseExprCounter::default();
30 for expr in project.exprs() {
31 expr_counter.visit_expr(expr);
32 }
33
34 if expr_counter.counter.values().all(|counter| *counter <= 1) {
35 return None;
36 }
37
38 let (exprs, input) = project.clone().decompose();
39 let input_schema_len = input.schema().len();
40 let mut cse_rewriter = CseRewriter::new(expr_counter, input_schema_len);
41 let top_project_exprs = exprs
42 .into_iter()
43 .map(|expr| cse_rewriter.rewrite_expr(expr))
44 .collect_vec();
45 let bottom_project_exprs = {
46 let mut exprs = Vec::with_capacity(input_schema_len + cse_rewriter.cse_mapping.len());
47 for (i, field) in input.schema().fields.iter().enumerate() {
48 let expr = ExprImpl::InputRef(InputRef::new(i, field.data_type.clone()).into());
49 exprs.push(expr);
50 }
51 exprs.extend(
52 cse_rewriter
53 .cse_mapping
54 .into_iter()
55 .sorted_by(|(_, v1), (_, v2)| Ord::cmp(&v1.index, &v2.index))
56 .map(|(k, _)| ExprImpl::FunctionCall(k.into())),
57 );
58 exprs
59 };
60 let bottom_project = LogicalProject::new(input, bottom_project_exprs);
61 let top_project = LogicalProject::new(bottom_project.into(), top_project_exprs);
62 Some(top_project.into())
63 }
64}
65
66impl CommonSubExprExtractRule {
67 pub fn create() -> BoxedRule {
68 Box::new(CommonSubExprExtractRule {})
69 }
70}