risingwave_frontend/optimizer/rule/
over_window_to_agg_and_join_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use itertools::Itertools;
16use risingwave_expr::window_function::WindowFuncKind;
17use risingwave_pb::expr::expr_node::Type;
18use risingwave_pb::plan_common::JoinType;
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{AggCall, ExprImpl, FunctionCall, InputRef, OrderBy};
22use crate::optimizer::plan_node::{
23    LogicalAgg, LogicalJoin, LogicalProject, LogicalShare, PlanTreeNodeUnary,
24};
25use crate::utils::{Condition, GroupBy};
26pub struct OverWindowToAggAndJoinRule;
27
28impl OverWindowToAggAndJoinRule {
29    pub fn create() -> BoxedRule {
30        Box::new(OverWindowToAggAndJoinRule)
31    }
32}
33
34impl Rule<Logical> for OverWindowToAggAndJoinRule {
35    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
36        let over_window = plan.as_logical_over_window()?;
37        let window_functions = over_window.window_functions();
38        if window_functions.iter().any(|window| {
39            !(window.order_by.is_empty()
40                && window.frame.bounds.start_is_unbounded()
41                && window.frame.bounds.end_is_unbounded())
42        }) {
43            return None;
44        }
45        // This rule should be applied after OverWindowSplitByWindowRule.
46        let group_exprs: Vec<ExprImpl> = window_functions[0]
47            .partition_by
48            .iter()
49            .map(|x| x.clone().into())
50            .collect_vec();
51        let mut select_exprs = group_exprs.clone();
52        for func in window_functions {
53            if let WindowFuncKind::Aggregate(kind) = &func.kind {
54                let agg_call = AggCall::new(
55                    kind.clone(),
56                    func.args.iter().map(|x| x.clone().into()).collect_vec(),
57                    false,
58                    OrderBy::any(),
59                    Condition::true_cond(),
60                    vec![],
61                )
62                .ok()?;
63                select_exprs.push(agg_call.into());
64            } else {
65                return None;
66            }
67        }
68
69        let input_len = over_window.input().schema().len();
70        let mut out_fields = (0..input_len).collect_vec();
71        for i in 0..window_functions.len() {
72            out_fields.push(input_len + group_exprs.len() + i);
73        }
74        let common_input = LogicalShare::create(over_window.input());
75        let (agg, ..) = LogicalAgg::create(
76            select_exprs,
77            GroupBy::GroupKey(group_exprs),
78            None,
79            common_input.clone(),
80        )
81        .ok()?;
82        let on_clause = window_functions[0].partition_by.iter().enumerate().fold(
83            Condition::true_cond(),
84            |on_clause, (idx, x)| {
85                on_clause.and(Condition::with_expr(
86                    FunctionCall::new(
87                        Type::Equal,
88                        vec![
89                            x.clone().into(),
90                            InputRef::new(idx + input_len, x.data_type.clone()).into(),
91                        ],
92                    )
93                    .unwrap()
94                    .into(),
95                ))
96            },
97        );
98        Some(
99            LogicalProject::with_out_col_idx(
100                LogicalJoin::new(common_input, agg, JoinType::Inner, on_clause).into(),
101                out_fields.into_iter(),
102            )
103            .into(),
104        )
105    }
106}