risingwave_frontend/optimizer/rule/
mv_selection_rule.rs1use std::collections::HashMap;
16
17use super::prelude::{PlanRef, *};
18use crate::optimizer::optimizer_context::MaterializedViewCandidate;
19use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef, PlanAggCall};
20use crate::optimizer::plan_node::{
21 LogicalAgg, LogicalPlanRef, LogicalProject, LogicalScan, PlanTreeNodeUnary,
22};
23use crate::optimizer::rule::{BoxedRule, Rule};
24use crate::utils::{ColIndexMapping, IndexSet};
25
26pub struct MvSelectionRule;
27
28impl MvSelectionRule {
29 pub fn create() -> BoxedRule<Logical> {
30 Box::new(Self)
31 }
32}
33
34impl Rule<Logical> for MvSelectionRule {
35 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
36 let ctx = plan.ctx();
37 for candidate in ctx.batch_mview_candidates().iter() {
38 if let Some(rewritten) = Self::rewrite_with_candidate(&plan, candidate) {
39 return Some(rewritten);
40 }
41 }
42 None
43 }
44}
45
46impl MvSelectionRule {
47 fn rewrite_with_candidate(
48 plan: &PlanRef,
49 candidate: &MaterializedViewCandidate,
50 ) -> Option<PlanRef> {
51 if plan == &candidate.plan {
52 return Some(Self::candidate_scan(candidate, plan.ctx())?.into());
53 }
54 Self::agg_rollup_rewrite(plan, candidate)
55 }
56
57 fn candidate_scan(
58 candidate: &MaterializedViewCandidate,
59 ctx: crate::optimizer::optimizer_context::OptimizerContextRef,
60 ) -> Option<LogicalScan> {
61 let output_len = candidate.plan.schema().len();
62 debug_assert!(output_len <= candidate.table.columns().len());
63 let output_col_idx = (0..output_len).collect();
64 Some(
65 LogicalScan::create(candidate.table.clone(), ctx, None)
66 .clone_with_output_indices(output_col_idx),
67 )
68 }
69
70 fn agg_rollup_rewrite(
94 plan: &PlanRef,
95 candidate: &MaterializedViewCandidate,
96 ) -> Option<PlanRef> {
97 let query_agg = plan.as_logical_agg()?;
98 let (mv_agg, mv_agg_to_mv_output) = Self::extract_mv_candidate_agg(candidate)?;
99
100 if !query_agg.grouping_sets().is_empty() || !mv_agg.grouping_sets().is_empty() {
101 return None;
102 }
103 if query_agg.agg_calls().iter().any(|call| call.distinct)
105 || mv_agg.agg_calls().iter().any(|call| call.distinct)
106 {
107 return None;
108 }
109
110 let (query_base_input, query_input_to_base) =
111 Self::agg_input_to_base_mapping(query_agg.clone())?;
112 let (mv_base_input, mv_input_to_base) = Self::agg_input_to_base_mapping(mv_agg.clone())?;
113 if query_base_input != mv_base_input {
114 return None;
115 }
116
117 let mut mv_group_key_pos = HashMap::new();
118 for (pos, col_idx) in mv_agg.group_key().indices().enumerate() {
119 let mv_output_col_idx = *mv_agg_to_mv_output.get(pos)?;
120 mv_group_key_pos.insert(*mv_input_to_base.get(col_idx)?, mv_output_col_idx);
121 }
122
123 let mut query_group_key_in_mv_output = Vec::with_capacity(query_agg.group_key().len());
124 for col_idx in query_agg.group_key().indices() {
125 let base_col_idx = *query_input_to_base.get(col_idx)?;
126 query_group_key_in_mv_output.push(*mv_group_key_pos.get(&base_col_idx)?);
127 }
128
129 let mut normalized_call_to_mv_idx: HashMap<PlanAggCall, usize> = HashMap::new();
130 let mut rewritten_agg_calls = Vec::with_capacity(query_agg.agg_calls().len());
131 for query_call in query_agg.agg_calls() {
132 let normalized_query_call = Self::normalize_agg_call(query_call, &query_input_to_base)?;
133 let mv_call_idx = if let Some(mv_call_idx) = normalized_call_to_mv_idx
134 .get(&normalized_query_call)
135 .copied()
136 {
137 mv_call_idx
138 } else {
139 let mv_call_idx = mv_agg
140 .agg_calls()
141 .iter()
142 .enumerate()
143 .find(|(_, mv_call)| {
144 Self::normalize_agg_call(mv_call, &mv_input_to_base).is_some_and(
145 |normalized_mv_call| normalized_mv_call == normalized_query_call,
146 )
147 })?
148 .0;
149 normalized_call_to_mv_idx.insert(normalized_query_call.clone(), mv_call_idx);
150 mv_call_idx
151 };
152
153 query_call.agg_type.partial_to_total()?;
155 let mv_agg_output_col_idx = mv_agg.group_key().len() + mv_call_idx;
156 let mv_output_col_idx = *mv_agg_to_mv_output.get(mv_agg_output_col_idx)?;
157 rewritten_agg_calls.push(query_call.partial_to_total_agg_call(mv_output_col_idx));
158 }
159
160 let mv_scan: LogicalPlanRef = Self::candidate_scan(candidate, plan.ctx())?.into();
161
162 let rewritten_group_key = IndexSet::from_iter(query_group_key_in_mv_output.iter().copied());
163 let rewritten_agg: LogicalPlanRef =
164 Agg::new(rewritten_agg_calls, rewritten_group_key.clone(), mv_scan)
165 .with_enable_two_phase(query_agg.core().two_phase_agg_enabled())
166 .into();
167
168 let mut output_col_idx =
169 Vec::with_capacity(query_agg.group_key().len() + query_agg.agg_calls().len());
170 for group_col in query_group_key_in_mv_output {
171 let col_pos = rewritten_group_key
172 .indices()
173 .position(|idx| idx == group_col)
174 .expect("group key must exist");
175 output_col_idx.push(col_pos);
176 }
177 output_col_idx
178 .extend((0..query_agg.agg_calls().len()).map(|idx| rewritten_group_key.len() + idx));
179
180 if output_col_idx.iter().copied().eq(0..output_col_idx.len()) {
181 Some(rewritten_agg)
182 } else {
183 Some(LogicalProject::with_out_col_idx(rewritten_agg, output_col_idx.into_iter()).into())
184 }
185 }
186
187 fn agg_input_to_base_mapping(agg: LogicalAgg) -> Option<(PlanRef, Vec<usize>)> {
188 let agg_input = agg.input();
189 if let Some(proj) = agg_input.as_logical_project() {
190 let input_to_base = proj.try_as_projection()?;
191 if let Some(scan) = proj.input().as_logical_scan() {
192 let scan_base_columns = input_to_base
193 .iter()
194 .map(|idx| scan.output_col_idx().get(*idx).copied())
195 .collect::<Option<Vec<usize>>>()?;
196
197 let canonical_scan =
200 scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
201 Some((canonical_scan.into(), scan_base_columns))
202 } else {
203 Some((proj.input(), input_to_base))
204 }
205 } else if let Some(scan) = agg_input.as_logical_scan() {
206 let scan_base_columns = scan.output_col_idx().clone();
207 let canonical_scan =
208 scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
209 Some((canonical_scan.into(), scan_base_columns))
210 } else {
211 None
212 }
213 }
214
215 fn normalize_agg_call(call: &PlanAggCall, input_to_base: &[usize]) -> Option<PlanAggCall> {
216 let mut normalized = call.clone();
217 let index_mapping = ColIndexMapping::new(
218 input_to_base
219 .iter()
220 .copied()
221 .map(Some)
222 .collect::<Vec<Option<usize>>>(),
223 input_to_base.iter().max().copied().unwrap_or(0) + 1,
224 );
225 normalized.rewrite_input_index(index_mapping);
226 Some(normalized)
227 }
228
229 fn extract_mv_candidate_agg(
230 candidate: &MaterializedViewCandidate,
231 ) -> Option<(LogicalAgg, Vec<usize>)> {
232 if let Some(mv_agg) = candidate.plan.as_logical_agg() {
233 return Some((mv_agg.clone(), (0..mv_agg.schema().len()).collect()));
234 }
235
236 let mv_project = candidate.plan.as_logical_project()?;
237 let mv_project_input = mv_project.input();
238 let mv_agg = mv_project_input.as_logical_agg()?;
239 let mv_output_to_mv_agg = mv_project.try_as_projection()?;
240
241 let mut mv_agg_to_mv_output = vec![None; mv_agg.schema().len()];
242 for (mv_output_idx, mv_agg_idx) in mv_output_to_mv_agg.into_iter().enumerate() {
243 if mv_agg_idx >= mv_agg_to_mv_output.len() {
244 return None;
245 }
246 if mv_agg_to_mv_output[mv_agg_idx].is_none() {
247 mv_agg_to_mv_output[mv_agg_idx] = Some(mv_output_idx);
250 }
251 }
252 let mv_agg_to_mv_output = mv_agg_to_mv_output
253 .into_iter()
254 .collect::<Option<Vec<usize>>>()?;
255 Some((mv_agg.clone(), mv_agg_to_mv_output))
256 }
257}