// risingwave_frontend/optimizer/rule/over_window_to_agg_and_join_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use itertools::Itertools;
16use risingwave_expr::window_function::WindowFuncKind;
17use risingwave_pb::expr::expr_node::Type;
18use risingwave_pb::plan_common::JoinType;
19
20use super::{BoxedRule, Rule};
21use crate::PlanRef;
22use crate::expr::{AggCall, ExprImpl, FunctionCall, InputRef, OrderBy};
23use crate::optimizer::plan_node::{
24    LogicalAgg, LogicalJoin, LogicalProject, LogicalShare, PlanTreeNodeUnary,
25};
26use crate::utils::{Condition, GroupBy};
27pub struct OverWindowToAggAndJoinRule;
28
29impl OverWindowToAggAndJoinRule {
30    pub fn create() -> BoxedRule {
31        Box::new(OverWindowToAggAndJoinRule)
32    }
33}
34
35impl Rule for OverWindowToAggAndJoinRule {
36    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
37        let over_window = plan.as_logical_over_window()?;
38        let window_functions = over_window.window_functions();
39        if window_functions.iter().any(|window| {
40            !(window.order_by.is_empty()
41                && window.frame.bounds.start_is_unbounded()
42                && window.frame.bounds.end_is_unbounded())
43        }) {
44            return None;
45        }
46        // This rule should be applied after OverWindowSplitByWindowRule.
47        let group_exprs: Vec<ExprImpl> = window_functions[0]
48            .partition_by
49            .iter()
50            .map(|x| x.clone().into())
51            .collect_vec();
52        let mut select_exprs = group_exprs.clone();
53        for func in window_functions {
54            if let WindowFuncKind::Aggregate(kind) = &func.kind {
55                let agg_call = AggCall::new(
56                    kind.clone(),
57                    func.args.iter().map(|x| x.clone().into()).collect_vec(),
58                    false,
59                    OrderBy::any(),
60                    Condition::true_cond(),
61                    vec![],
62                )
63                .ok()?;
64                select_exprs.push(agg_call.into());
65            } else {
66                return None;
67            }
68        }
69
70        let input_len = over_window.input().schema().len();
71        let mut out_fields = (0..input_len).collect_vec();
72        for i in 0..window_functions.len() {
73            out_fields.push(input_len + group_exprs.len() + i);
74        }
75        let common_input = LogicalShare::create(over_window.input());
76        let (agg, ..) = LogicalAgg::create(
77            select_exprs,
78            GroupBy::GroupKey(group_exprs),
79            None,
80            common_input.clone(),
81        )
82        .ok()?;
83        let on_clause = window_functions[0].partition_by.iter().enumerate().fold(
84            Condition::true_cond(),
85            |on_clause, (idx, x)| {
86                on_clause.and(Condition::with_expr(
87                    FunctionCall::new(
88                        Type::Equal,
89                        vec![
90                            x.clone().into(),
91                            InputRef::new(idx + input_len, x.data_type.clone()).into(),
92                        ],
93                    )
94                    .unwrap()
95                    .into(),
96                ))
97            },
98        );
99        Some(
100            LogicalProject::with_out_col_idx(
101                LogicalJoin::new(common_input, agg, JoinType::Inner, on_clause).into(),
102                out_fields.into_iter(),
103            )
104            .into(),
105        )
106    }
107}