risingwave_frontend/optimizer/rule/
over_window_to_agg_and_join_rule.rs1use itertools::Itertools;
16use risingwave_expr::window_function::WindowFuncKind;
17use risingwave_pb::expr::expr_node::Type;
18use risingwave_pb::plan_common::JoinType;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{AggCall, ExprImpl, FunctionCall, InputRef, OrderBy};
22use crate::optimizer::plan_node::{
23 LogicalAgg, LogicalJoin, LogicalProject, LogicalShare, PlanTreeNodeUnary,
24};
25use crate::utils::{Condition, GroupBy};
/// Optimizer rule that rewrites a `LogicalOverWindow` whose window functions are all aggregates
/// over their entire partition (no `ORDER BY`, frame unbounded on both ends) into a grouped
/// `LogicalAgg` over a shared copy of the input, joined back to that input on the partition keys.
pub struct OverWindowToAggAndJoinRule;
27
28impl OverWindowToAggAndJoinRule {
29 pub fn create() -> BoxedRule {
30 Box::new(OverWindowToAggAndJoinRule)
31 }
32}
33
34impl Rule<Logical> for OverWindowToAggAndJoinRule {
35 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
36 let over_window = plan.as_logical_over_window()?;
37 let window_functions = over_window.window_functions();
38 if window_functions.iter().any(|window| {
39 !(window.order_by.is_empty()
40 && window.frame.bounds.start_is_unbounded()
41 && window.frame.bounds.end_is_unbounded())
42 }) {
43 return None;
44 }
45 let group_exprs: Vec<ExprImpl> = window_functions[0]
47 .partition_by
48 .iter()
49 .map(|x| x.clone().into())
50 .collect_vec();
51 let mut select_exprs = group_exprs.clone();
52 for func in window_functions {
53 if let WindowFuncKind::Aggregate(kind) = &func.kind {
54 let agg_call = AggCall::new(
55 kind.clone(),
56 func.args.iter().map(|x| x.clone().into()).collect_vec(),
57 false,
58 OrderBy::any(),
59 Condition::true_cond(),
60 vec![],
61 )
62 .ok()?;
63 select_exprs.push(agg_call.into());
64 } else {
65 return None;
66 }
67 }
68
69 let input_len = over_window.input().schema().len();
70 let mut out_fields = (0..input_len).collect_vec();
71 for i in 0..window_functions.len() {
72 out_fields.push(input_len + group_exprs.len() + i);
73 }
74 let common_input = LogicalShare::create(over_window.input());
75 let (agg, ..) = LogicalAgg::create(
76 select_exprs,
77 GroupBy::GroupKey(group_exprs),
78 None,
79 common_input.clone(),
80 )
81 .ok()?;
82 let on_clause = window_functions[0].partition_by.iter().enumerate().fold(
83 Condition::true_cond(),
84 |on_clause, (idx, x)| {
85 on_clause.and(Condition::with_expr(
86 FunctionCall::new(
87 Type::Equal,
88 vec![
89 x.clone().into(),
90 InputRef::new(idx + input_len, x.data_type.clone()).into(),
91 ],
92 )
93 .unwrap()
94 .into(),
95 ))
96 },
97 );
98 Some(
99 LogicalProject::with_out_col_idx(
100 LogicalJoin::new(common_input, agg, JoinType::Inner, on_clause).into(),
101 out_fields.into_iter(),
102 )
103 .into(),
104 )
105 }
106}