risingwave_frontend/optimizer/rule/
mv_selection_rule.rs

1// Copyright 2026 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use super::prelude::{PlanRef, *};
18use crate::optimizer::optimizer_context::MaterializedViewCandidate;
19use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef, PlanAggCall};
20use crate::optimizer::plan_node::{
21    LogicalAgg, LogicalPlanRef, LogicalProject, LogicalScan, PlanTreeNodeUnary,
22};
23use crate::optimizer::rule::{BoxedRule, Rule};
24use crate::utils::{ColIndexMapping, IndexSet};
25
/// Optimizer rule that tries to replace a logical plan with a scan (optionally
/// followed by a roll-up aggregation) over one of the batch materialized-view
/// candidates registered on the optimizer context.
pub struct MvSelectionRule;
27
28impl MvSelectionRule {
29    pub fn create() -> BoxedRule<Logical> {
30        Box::new(Self)
31    }
32}
33
34impl Rule<Logical> for MvSelectionRule {
35    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
36        let ctx = plan.ctx();
37        for candidate in ctx.batch_mview_candidates().iter() {
38            if let Some(rewritten) = Self::rewrite_with_candidate(&plan, candidate) {
39                return Some(rewritten);
40            }
41        }
42        None
43    }
44}
45
46impl MvSelectionRule {
47    fn rewrite_with_candidate(
48        plan: &PlanRef,
49        candidate: &MaterializedViewCandidate,
50    ) -> Option<PlanRef> {
51        if plan == &candidate.plan {
52            return Some(Self::candidate_scan(candidate, plan.ctx())?.into());
53        }
54        Self::agg_rollup_rewrite(plan, candidate)
55    }
56
57    fn candidate_scan(
58        candidate: &MaterializedViewCandidate,
59        ctx: crate::optimizer::optimizer_context::OptimizerContextRef,
60    ) -> Option<LogicalScan> {
61        let output_len = candidate.plan.schema().len();
62        debug_assert!(output_len <= candidate.table.columns().len());
63        let output_col_idx = (0..output_len).collect();
64        Some(
65            LogicalScan::create(candidate.table.clone(), ctx, None)
66                .clone_with_output_indices(output_col_idx),
67        )
68    }
69
70    /// Rewrite aggregate query by rolling up aggregate MV states.
71    ///
72    /// Graph (matching + rewrite):
73    /// ```text
74    /// 1) Normalize both sides to base-column lineage
75    ///
76    /// Query side                                  MV candidate side
77    /// Agg_q(Gq, Aq)                               [Project_mv]?
78    ///   └─ [Project_q]?                             └─ Agg_mv(Gm, Am)
79    ///       └─ BaseInput                                  └─ [Project_mv_in]?
80    ///                                                         └─ BaseInput
81    ///
82    /// 2) Match under lineage mapping
83    /// - same BaseInput
84    /// - map(Gq) ⊆ map(Gm)
85    /// - for each aq in Aq, find semantically equivalent am in Am
86    /// - aq must support partial_to_total
87    ///
88    /// 3) Build rewritten plan
89    /// Agg_rollup(group = mapped Gq, calls = total(aq) on MV columns)
90    ///   └─ Scan(MV table)
91    /// [Project] (optional, restore query output order)
92    /// ```
93    fn agg_rollup_rewrite(
94        plan: &PlanRef,
95        candidate: &MaterializedViewCandidate,
96    ) -> Option<PlanRef> {
97        let query_agg = plan.as_logical_agg()?;
98        let (mv_agg, mv_agg_to_mv_output) = Self::extract_mv_candidate_agg(candidate)?;
99
100        if !query_agg.grouping_sets().is_empty() || !mv_agg.grouping_sets().is_empty() {
101            return None;
102        }
103        // DISTINCT aggregates are not composable via partial-to-total rollup.
104        if query_agg.agg_calls().iter().any(|call| call.distinct)
105            || mv_agg.agg_calls().iter().any(|call| call.distinct)
106        {
107            return None;
108        }
109
110        let (query_base_input, query_input_to_base) =
111            Self::agg_input_to_base_mapping(query_agg.clone())?;
112        let (mv_base_input, mv_input_to_base) = Self::agg_input_to_base_mapping(mv_agg.clone())?;
113        if query_base_input != mv_base_input {
114            return None;
115        }
116
117        let mut mv_group_key_pos = HashMap::new();
118        for (pos, col_idx) in mv_agg.group_key().indices().enumerate() {
119            let mv_output_col_idx = *mv_agg_to_mv_output.get(pos)?;
120            mv_group_key_pos.insert(*mv_input_to_base.get(col_idx)?, mv_output_col_idx);
121        }
122
123        let mut query_group_key_in_mv_output = Vec::with_capacity(query_agg.group_key().len());
124        for col_idx in query_agg.group_key().indices() {
125            let base_col_idx = *query_input_to_base.get(col_idx)?;
126            query_group_key_in_mv_output.push(*mv_group_key_pos.get(&base_col_idx)?);
127        }
128
129        let mut normalized_call_to_mv_idx: HashMap<PlanAggCall, usize> = HashMap::new();
130        let mut rewritten_agg_calls = Vec::with_capacity(query_agg.agg_calls().len());
131        for query_call in query_agg.agg_calls() {
132            let normalized_query_call = Self::normalize_agg_call(query_call, &query_input_to_base)?;
133            let mv_call_idx = if let Some(mv_call_idx) = normalized_call_to_mv_idx
134                .get(&normalized_query_call)
135                .copied()
136            {
137                mv_call_idx
138            } else {
139                let mv_call_idx = mv_agg
140                    .agg_calls()
141                    .iter()
142                    .enumerate()
143                    .find(|(_, mv_call)| {
144                        Self::normalize_agg_call(mv_call, &mv_input_to_base).is_some_and(
145                            |normalized_mv_call| normalized_mv_call == normalized_query_call,
146                        )
147                    })?
148                    .0;
149                normalized_call_to_mv_idx.insert(normalized_query_call.clone(), mv_call_idx);
150                mv_call_idx
151            };
152
153            // Only composable aggregate kinds can be rolled up from MV states.
154            query_call.agg_type.partial_to_total()?;
155            let mv_agg_output_col_idx = mv_agg.group_key().len() + mv_call_idx;
156            let mv_output_col_idx = *mv_agg_to_mv_output.get(mv_agg_output_col_idx)?;
157            rewritten_agg_calls.push(query_call.partial_to_total_agg_call(mv_output_col_idx));
158        }
159
160        let mv_scan: LogicalPlanRef = Self::candidate_scan(candidate, plan.ctx())?.into();
161
162        let rewritten_group_key = IndexSet::from_iter(query_group_key_in_mv_output.iter().copied());
163        let rewritten_agg: LogicalPlanRef =
164            Agg::new(rewritten_agg_calls, rewritten_group_key.clone(), mv_scan)
165                .with_enable_two_phase(query_agg.core().two_phase_agg_enabled())
166                .into();
167
168        let mut output_col_idx =
169            Vec::with_capacity(query_agg.group_key().len() + query_agg.agg_calls().len());
170        for group_col in query_group_key_in_mv_output {
171            let col_pos = rewritten_group_key
172                .indices()
173                .position(|idx| idx == group_col)
174                .expect("group key must exist");
175            output_col_idx.push(col_pos);
176        }
177        output_col_idx
178            .extend((0..query_agg.agg_calls().len()).map(|idx| rewritten_group_key.len() + idx));
179
180        if output_col_idx.iter().copied().eq(0..output_col_idx.len()) {
181            Some(rewritten_agg)
182        } else {
183            Some(LogicalProject::with_out_col_idx(rewritten_agg, output_col_idx.into_iter()).into())
184        }
185    }
186
187    fn agg_input_to_base_mapping(agg: LogicalAgg) -> Option<(PlanRef, Vec<usize>)> {
188        let agg_input = agg.input();
189        if let Some(proj) = agg_input.as_logical_project() {
190            let input_to_base = proj.try_as_projection()?;
191            if let Some(scan) = proj.input().as_logical_scan() {
192                let scan_base_columns = input_to_base
193                    .iter()
194                    .map(|idx| scan.output_col_idx().get(*idx).copied())
195                    .collect::<Option<Vec<usize>>>()?;
196
197                // Canonicalize scan shape to avoid false mismatch from different output column
198                // subsets.
199                let canonical_scan =
200                    scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
201                Some((canonical_scan.into(), scan_base_columns))
202            } else {
203                Some((proj.input(), input_to_base))
204            }
205        } else if let Some(scan) = agg_input.as_logical_scan() {
206            let scan_base_columns = scan.output_col_idx().clone();
207            let canonical_scan =
208                scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
209            Some((canonical_scan.into(), scan_base_columns))
210        } else {
211            None
212        }
213    }
214
215    fn normalize_agg_call(call: &PlanAggCall, input_to_base: &[usize]) -> Option<PlanAggCall> {
216        let mut normalized = call.clone();
217        let index_mapping = ColIndexMapping::new(
218            input_to_base
219                .iter()
220                .copied()
221                .map(Some)
222                .collect::<Vec<Option<usize>>>(),
223            input_to_base.iter().max().copied().unwrap_or(0) + 1,
224        );
225        normalized.rewrite_input_index(index_mapping);
226        Some(normalized)
227    }
228
229    fn extract_mv_candidate_agg(
230        candidate: &MaterializedViewCandidate,
231    ) -> Option<(LogicalAgg, Vec<usize>)> {
232        if let Some(mv_agg) = candidate.plan.as_logical_agg() {
233            return Some((mv_agg.clone(), (0..mv_agg.schema().len()).collect()));
234        }
235
236        let mv_project = candidate.plan.as_logical_project()?;
237        let mv_project_input = mv_project.input();
238        let mv_agg = mv_project_input.as_logical_agg()?;
239        let mv_output_to_mv_agg = mv_project.try_as_projection()?;
240
241        let mut mv_agg_to_mv_output = vec![None; mv_agg.schema().len()];
242        for (mv_output_idx, mv_agg_idx) in mv_output_to_mv_agg.into_iter().enumerate() {
243            if mv_agg_idx >= mv_agg_to_mv_output.len() {
244                return None;
245            }
246            if mv_agg_to_mv_output[mv_agg_idx].is_none() {
247                // Keep the first projected output for each agg column.
248                // Duplicate projections (e.g. s1/s2 from the same agg output) are valid.
249                mv_agg_to_mv_output[mv_agg_idx] = Some(mv_output_idx);
250            }
251        }
252        let mv_agg_to_mv_output = mv_agg_to_mv_output
253            .into_iter()
254            .collect::<Option<Vec<usize>>>()?;
255        Some((mv_agg.clone(), mv_agg_to_mv_output))
256    }
257}