risingwave_frontend/optimizer/rule/
over_window_to_agg_and_join_rule.rs1use itertools::Itertools;
16use risingwave_expr::window_function::WindowFuncKind;
17use risingwave_pb::expr::expr_node::Type;
18use risingwave_pb::plan_common::JoinType;
19
20use super::{BoxedRule, Rule};
21use crate::PlanRef;
22use crate::expr::{AggCall, ExprImpl, FunctionCall, InputRef, OrderBy};
23use crate::optimizer::plan_node::{
24 LogicalAgg, LogicalJoin, LogicalProject, LogicalShare, PlanTreeNodeUnary,
25};
26use crate::utils::{Condition, GroupBy};
27pub struct OverWindowToAggAndJoinRule;
28
29impl OverWindowToAggAndJoinRule {
30 pub fn create() -> BoxedRule {
31 Box::new(OverWindowToAggAndJoinRule)
32 }
33}
34
35impl Rule for OverWindowToAggAndJoinRule {
36 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
37 let over_window = plan.as_logical_over_window()?;
38 let window_functions = over_window.window_functions();
39 if window_functions.iter().any(|window| {
40 !(window.order_by.is_empty()
41 && window.frame.bounds.start_is_unbounded()
42 && window.frame.bounds.end_is_unbounded())
43 }) {
44 return None;
45 }
46 let group_exprs: Vec<ExprImpl> = window_functions[0]
48 .partition_by
49 .iter()
50 .map(|x| x.clone().into())
51 .collect_vec();
52 let mut select_exprs = group_exprs.clone();
53 for func in window_functions {
54 if let WindowFuncKind::Aggregate(kind) = &func.kind {
55 let agg_call = AggCall::new(
56 kind.clone(),
57 func.args.iter().map(|x| x.clone().into()).collect_vec(),
58 false,
59 OrderBy::any(),
60 Condition::true_cond(),
61 vec![],
62 )
63 .ok()?;
64 select_exprs.push(agg_call.into());
65 } else {
66 return None;
67 }
68 }
69
70 let input_len = over_window.input().schema().len();
71 let mut out_fields = (0..input_len).collect_vec();
72 for i in 0..window_functions.len() {
73 out_fields.push(input_len + group_exprs.len() + i);
74 }
75 let common_input = LogicalShare::create(over_window.input());
76 let (agg, ..) = LogicalAgg::create(
77 select_exprs,
78 GroupBy::GroupKey(group_exprs),
79 None,
80 common_input.clone(),
81 )
82 .ok()?;
83 let on_clause = window_functions[0].partition_by.iter().enumerate().fold(
84 Condition::true_cond(),
85 |on_clause, (idx, x)| {
86 on_clause.and(Condition::with_expr(
87 FunctionCall::new(
88 Type::Equal,
89 vec![
90 x.clone().into(),
91 InputRef::new(idx + input_len, x.data_type.clone()).into(),
92 ],
93 )
94 .unwrap()
95 .into(),
96 ))
97 },
98 );
99 Some(
100 LogicalProject::with_out_col_idx(
101 LogicalJoin::new(common_input, agg, JoinType::Inner, on_clause).into(),
102 out_fields.into_iter(),
103 )
104 .into(),
105 )
106 }
107}