// risingwave_frontend/optimizer/plan_expr_rewriter/cse_rewriter.rs

use std::collections::HashMap;

use crate::expr::{Expr, ExprImpl, ExprRewriter, FunctionCall, InputRef};
use crate::optimizer::plan_expr_visitor::CseExprCounter;
20#[derive(Default)]
21pub struct CseRewriter {
22 pub expr_counter: CseExprCounter,
23 pub cse_input_ref_offset: usize,
24 pub cse_mapping: HashMap<FunctionCall, InputRef>,
25}
26
27impl CseRewriter {
28 pub fn new(expr_counter: CseExprCounter, cse_input_ref_offset: usize) -> Self {
29 Self {
30 expr_counter,
31 cse_input_ref_offset,
32 cse_mapping: HashMap::default(),
33 }
34 }
35}
36
37impl ExprRewriter for CseRewriter {
38 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
39 if let Some(count) = self.expr_counter.counter.get(&func_call)
40 && *count > 1
41 {
42 if let Some(expr) = self.cse_mapping.get(&func_call) {
43 let expr: ExprImpl = ExprImpl::InputRef(expr.clone().into());
44 return expr;
45 }
46 let input_ref = InputRef::new(self.cse_input_ref_offset, func_call.return_type());
47 self.cse_input_ref_offset += 1;
48 self.cse_mapping.insert(func_call, input_ref.clone());
49 return ExprImpl::InputRef(input_ref.into());
50 }
51
52 let (func_type, inputs, ret) = func_call.decompose();
53 let inputs = inputs
54 .into_iter()
55 .map(|expr| self.rewrite_expr(expr))
56 .collect();
57 FunctionCall::new_unchecked(func_type, inputs, ret).into()
58 }
59}