1use std::collections::HashMap;
16
17use super::prelude::{PlanRef, *};
18use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall};
19use crate::optimizer::optimizer_context::MaterializedViewCandidate;
20use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef, PlanAggCall};
21use crate::optimizer::plan_node::{
22 LogicalAgg, LogicalFilter, LogicalMultiJoinBuilder, LogicalPlanRef, LogicalProject,
23 LogicalScan, PlanTreeNodeUnary,
24};
25use crate::optimizer::rule::{BoxedRule, Rule};
26use crate::utils::{ColIndexMapping, Condition, IndexSet};
27
28pub struct MvSelectionRule;
29
30impl MvSelectionRule {
31 pub fn create() -> BoxedRule<Logical> {
32 Box::new(Self)
33 }
34}
35
36impl Rule<Logical> for MvSelectionRule {
37 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
38 let ctx = plan.ctx();
39 for candidate in ctx.batch_mview_candidates().iter() {
40 if let Some(rewritten) = Self::rewrite_with_candidate(&plan, candidate) {
41 return Some(rewritten);
42 }
43 }
44 None
45 }
46}
47
48impl MvSelectionRule {
49 fn rewrite_with_candidate(
50 plan: &PlanRef,
51 candidate: &MaterializedViewCandidate,
52 ) -> Option<PlanRef> {
53 if plan == &candidate.plan {
54 return Some(Self::candidate_scan(candidate, plan.ctx())?.into());
55 }
56 Self::inner_join_rewrite(plan, candidate)
57 .or_else(|| Self::agg_rollup_rewrite(plan, candidate))
58 }
59
60 fn candidate_scan(
61 candidate: &MaterializedViewCandidate,
62 ctx: crate::optimizer::optimizer_context::OptimizerContextRef,
63 ) -> Option<LogicalScan> {
64 let output_len = candidate.plan.schema().len();
65 debug_assert!(output_len <= candidate.table.columns().len());
66 let output_col_idx = (0..output_len).collect();
67 Some(
68 LogicalScan::create(candidate.table.clone(), ctx, None)
69 .clone_with_output_indices(output_col_idx),
70 )
71 }
72
73 fn agg_rollup_rewrite(
97 plan: &PlanRef,
98 candidate: &MaterializedViewCandidate,
99 ) -> Option<PlanRef> {
100 let query_agg = plan.as_logical_agg()?;
101 let (mv_agg, mv_agg_to_mv_output) = Self::extract_mv_candidate_agg(candidate)?;
102
103 if !query_agg.grouping_sets().is_empty() || !mv_agg.grouping_sets().is_empty() {
104 return None;
105 }
106 if query_agg.agg_calls().iter().any(|call| call.distinct)
108 || mv_agg.agg_calls().iter().any(|call| call.distinct)
109 {
110 return None;
111 }
112
113 let (query_base_input, query_input_to_base) =
114 Self::agg_input_to_base_mapping(query_agg.clone())?;
115 let (mv_base_input, mv_input_to_base) = Self::agg_input_to_base_mapping(mv_agg.clone())?;
116 if query_base_input != mv_base_input {
117 return None;
118 }
119
120 let mut mv_group_key_pos = HashMap::new();
121 for (pos, col_idx) in mv_agg.group_key().indices().enumerate() {
122 let mv_output_col_idx = *mv_agg_to_mv_output.get(pos)?;
123 mv_group_key_pos.insert(*mv_input_to_base.get(col_idx)?, mv_output_col_idx);
124 }
125
126 let mut query_group_key_in_mv_output = Vec::with_capacity(query_agg.group_key().len());
127 for col_idx in query_agg.group_key().indices() {
128 let base_col_idx = *query_input_to_base.get(col_idx)?;
129 query_group_key_in_mv_output.push(*mv_group_key_pos.get(&base_col_idx)?);
130 }
131
132 let mut normalized_call_to_mv_idx: HashMap<PlanAggCall, usize> = HashMap::new();
133 let mut rewritten_agg_calls = Vec::with_capacity(query_agg.agg_calls().len());
134 for query_call in query_agg.agg_calls() {
135 let normalized_query_call = Self::normalize_agg_call(query_call, &query_input_to_base)?;
136 let mv_call_idx = if let Some(mv_call_idx) = normalized_call_to_mv_idx
137 .get(&normalized_query_call)
138 .copied()
139 {
140 mv_call_idx
141 } else {
142 let mv_call_idx = mv_agg
143 .agg_calls()
144 .iter()
145 .enumerate()
146 .find(|(_, mv_call)| {
147 Self::normalize_agg_call(mv_call, &mv_input_to_base).is_some_and(
148 |normalized_mv_call| normalized_mv_call == normalized_query_call,
149 )
150 })?
151 .0;
152 normalized_call_to_mv_idx.insert(normalized_query_call.clone(), mv_call_idx);
153 mv_call_idx
154 };
155
156 query_call.agg_type.partial_to_total()?;
158 let mv_agg_output_col_idx = mv_agg.group_key().len() + mv_call_idx;
159 let mv_output_col_idx = *mv_agg_to_mv_output.get(mv_agg_output_col_idx)?;
160 rewritten_agg_calls.push(query_call.partial_to_total_agg_call(mv_output_col_idx));
161 }
162
163 let mv_scan: LogicalPlanRef = Self::candidate_scan(candidate, plan.ctx())?.into();
164
165 let rewritten_group_key = IndexSet::from_iter(query_group_key_in_mv_output.iter().copied());
166 let rewritten_agg: LogicalPlanRef =
167 Agg::new(rewritten_agg_calls, rewritten_group_key.clone(), mv_scan)
168 .with_enable_two_phase(query_agg.core().two_phase_agg_enabled())
169 .into();
170
171 let mut output_col_idx =
172 Vec::with_capacity(query_agg.group_key().len() + query_agg.agg_calls().len());
173 for group_col in query_group_key_in_mv_output {
174 let col_pos = rewritten_group_key
175 .indices()
176 .position(|idx| idx == group_col)
177 .expect("group key must exist");
178 output_col_idx.push(col_pos);
179 }
180 output_col_idx
181 .extend((0..query_agg.agg_calls().len()).map(|idx| rewritten_group_key.len() + idx));
182
183 if output_col_idx.iter().copied().eq(0..output_col_idx.len()) {
184 Some(rewritten_agg)
185 } else {
186 Some(LogicalProject::with_out_col_idx(rewritten_agg, output_col_idx.into_iter()).into())
187 }
188 }
189
190 fn agg_input_to_base_mapping(agg: LogicalAgg) -> Option<(PlanRef, Vec<usize>)> {
191 let agg_input = agg.input();
192 if let Some(proj) = agg_input.as_logical_project() {
193 let input_to_base = proj.try_as_projection()?;
194 if let Some(scan) = proj.input().as_logical_scan() {
195 let scan_base_columns = input_to_base
196 .iter()
197 .map(|idx| scan.output_col_idx().get(*idx).copied())
198 .collect::<Option<Vec<usize>>>()?;
199
200 let canonical_scan =
203 scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
204 Some((canonical_scan.into(), scan_base_columns))
205 } else {
206 Some((proj.input(), input_to_base))
207 }
208 } else if let Some(scan) = agg_input.as_logical_scan() {
209 let scan_base_columns = scan.output_col_idx().clone();
210 let canonical_scan =
211 scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
212 Some((canonical_scan.into(), scan_base_columns))
213 } else {
214 None
215 }
216 }
217
218 fn normalize_agg_call(call: &PlanAggCall, input_to_base: &[usize]) -> Option<PlanAggCall> {
219 let mut normalized = call.clone();
220 let index_mapping = ColIndexMapping::new(
221 input_to_base
222 .iter()
223 .copied()
224 .map(Some)
225 .collect::<Vec<Option<usize>>>(),
226 input_to_base.iter().max().copied().unwrap_or(0) + 1,
227 );
228 normalized.rewrite_input_index(index_mapping);
229 Some(normalized)
230 }
231
232 fn extract_mv_candidate_agg(
233 candidate: &MaterializedViewCandidate,
234 ) -> Option<(LogicalAgg, Vec<usize>)> {
235 if let Some(mv_agg) = candidate.plan.as_logical_agg() {
236 return Some((mv_agg.clone(), (0..mv_agg.schema().len()).collect()));
237 }
238
239 let mv_project = candidate.plan.as_logical_project()?;
240 let mv_project_input = mv_project.input();
241 let mv_agg = mv_project_input.as_logical_agg()?;
242 let mv_output_to_mv_agg = mv_project.try_as_projection()?;
243
244 let mut mv_agg_to_mv_output = vec![None; mv_agg.schema().len()];
245 for (mv_output_idx, mv_agg_idx) in mv_output_to_mv_agg.into_iter().enumerate() {
246 if mv_agg_idx >= mv_agg_to_mv_output.len() {
247 return None;
248 }
249 if mv_agg_to_mv_output[mv_agg_idx].is_none() {
250 mv_agg_to_mv_output[mv_agg_idx] = Some(mv_output_idx);
253 }
254 }
255 let mv_agg_to_mv_output = mv_agg_to_mv_output
256 .into_iter()
257 .collect::<Option<Vec<usize>>>()?;
258 Some((mv_agg.clone(), mv_agg_to_mv_output))
259 }
260
261 fn inner_join_rewrite(
262 plan: &PlanRef,
263 candidate: &MaterializedViewCandidate,
264 ) -> Option<PlanRef> {
265 let query_join = Self::extract_inner_join_rewrite(plan)?;
266 let mv_join = Self::extract_inner_join_rewrite(&candidate.plan)?;
267 if query_join.inputs.len() != mv_join.inputs.len() {
268 return None;
269 }
270
271 let query_input_to_query = (0..query_join.inputs.len()).collect::<Vec<_>>();
272 let query_base_offsets = Self::input_base_offsets(&query_join.inputs);
273 let total_base_len = query_join.total_base_len();
274
275 let query_conditions = Self::normalize_join_conditions(
276 &query_join,
277 &query_input_to_query,
278 &query_base_offsets,
279 total_base_len,
280 )?;
281
282 let query_output_to_base =
283 Self::normalize_join_outputs(&query_join, &query_input_to_query, &query_base_offsets)?;
284
285 for mv_input_to_query in Self::match_join_inputs(&query_join.inputs, &mv_join.inputs) {
286 let Some(mv_conditions) = Self::normalize_join_conditions(
287 &mv_join,
288 &mv_input_to_query,
289 &query_base_offsets,
290 total_base_len,
291 ) else {
292 continue;
293 };
294 let Some(rewritten_predicate) =
295 Self::subtract_conditions(query_conditions.clone(), mv_conditions)
296 else {
297 continue;
298 };
299 let Some(mv_output_to_base) =
300 Self::normalize_join_outputs(&mv_join, &mv_input_to_query, &query_base_offsets)
301 else {
302 continue;
303 };
304 let Some(base_to_mv_output) =
305 Self::invert_mapping(&mv_output_to_base, query_join.total_base_len())
306 else {
307 continue;
308 };
309 let Some(rewritten_predicate) =
310 Self::rewrite_condition_to_mv(rewritten_predicate, &base_to_mv_output)
311 else {
312 continue;
313 };
314 let Some(output_col_idx) = query_output_to_base
315 .iter()
316 .map(|base_idx| base_to_mv_output.get(*base_idx).copied().flatten())
317 .collect::<Option<Vec<_>>>()
318 else {
319 continue;
320 };
321
322 let mv_scan: LogicalPlanRef = Self::candidate_scan(candidate, plan.ctx())?.into();
323 let rewritten = LogicalFilter::create(mv_scan, rewritten_predicate);
324 return if output_col_idx.iter().copied().eq(0..output_col_idx.len()) {
325 Some(rewritten)
326 } else {
327 Some(LogicalProject::with_out_col_idx(rewritten, output_col_idx.into_iter()).into())
328 };
329 }
330
331 None
332 }
333
334 fn extract_inner_join_rewrite(plan: &PlanRef) -> Option<InnerJoinRewrite> {
335 let multijoin = LogicalMultiJoinBuilder::new(plan.clone());
336 let (output_indices, conjunctions, inputs, _) = multijoin.into_parts();
337 if inputs.len() < 2 {
338 return None;
339 }
340
341 let inputs = inputs
342 .iter()
343 .map(Self::extract_join_leaf)
344 .collect::<Option<Vec<_>>>()?;
345
346 Some(InnerJoinRewrite {
347 inputs,
348 conditions: Condition { conjunctions },
349 output_indices,
350 })
351 }
352
353 fn extract_join_leaf(plan: &PlanRef) -> Option<JoinLeafRewrite> {
354 if let Some(scan) = plan.as_logical_scan() {
355 let output_to_base = scan.output_col_idx().clone();
356 let base_scan =
357 scan.clone_with_output_indices((0..scan.table().columns().len()).collect());
358 let predicate = Self::rewrite_condition(
359 scan.predicate().clone(),
360 output_to_base.clone(),
361 base_scan.schema().len(),
362 )?;
363 return Some(JoinLeafRewrite {
364 base_scan: base_scan.into(),
365 output_to_base,
366 predicate,
367 });
368 }
369 if let Some(filter) = plan.as_logical_filter() {
370 let mut child = Self::extract_join_leaf(&filter.input())?;
371 let predicate = Self::rewrite_condition(
372 filter.predicate().clone(),
373 child.output_to_base.clone(),
374 child.base_len(),
375 )?;
376 child.predicate = child.predicate.and(predicate);
377 return Some(child);
378 }
379 if let Some(project) = plan.as_logical_project() {
380 let child = Self::extract_join_leaf(&project.input())?;
381 let output_to_base = project
382 .try_as_projection()?
383 .into_iter()
384 .map(|idx| child.output_to_base.get(idx).copied())
385 .collect::<Option<Vec<_>>>()?;
386 return Some(JoinLeafRewrite {
387 base_scan: child.base_scan,
388 output_to_base,
389 predicate: child.predicate,
390 });
391 }
392 None
393 }
394
395 fn input_base_offsets(inputs: &[JoinLeafRewrite]) -> Vec<usize> {
396 let mut offsets = Vec::with_capacity(inputs.len());
397 let mut offset = 0;
398 for input in inputs {
399 offsets.push(offset);
400 offset += input.base_len();
401 }
402 offsets
403 }
404
405 fn match_join_inputs(
406 query_inputs: &[JoinLeafRewrite],
407 mv_inputs: &[JoinLeafRewrite],
408 ) -> Vec<Vec<usize>> {
409 fn dfs(
410 mv_idx: usize,
411 query_inputs: &[JoinLeafRewrite],
412 mv_inputs: &[JoinLeafRewrite],
413 used_query: &mut [bool],
414 mapping: &mut [usize],
415 results: &mut Vec<Vec<usize>>,
416 ) {
417 if mv_idx == mv_inputs.len() {
418 results.push(mapping.to_vec());
419 return;
420 }
421 for query_idx in 0..query_inputs.len() {
422 if used_query[query_idx]
423 || mv_inputs[mv_idx].base_scan != query_inputs[query_idx].base_scan
424 || mv_inputs[mv_idx].base_len() != query_inputs[query_idx].base_len()
425 {
426 continue;
427 }
428 used_query[query_idx] = true;
429 mapping[mv_idx] = query_idx;
430 dfs(
431 mv_idx + 1,
432 query_inputs,
433 mv_inputs,
434 used_query,
435 mapping,
436 results,
437 );
438 used_query[query_idx] = false;
439 }
440 }
441
442 let mut used_query = vec![false; query_inputs.len()];
443 let mut mapping = vec![usize::MAX; mv_inputs.len()];
444 let mut results = vec![];
445 dfs(
446 0,
447 query_inputs,
448 mv_inputs,
449 &mut used_query,
450 &mut mapping,
451 &mut results,
452 );
453 results
454 }
455
456 fn normalize_join_conditions(
457 join: &InnerJoinRewrite,
458 input_to_query: &[usize],
459 query_base_offsets: &[usize],
460 total_base_len: usize,
461 ) -> Option<Condition> {
462 let join_output_to_base =
463 Self::join_output_to_base_mapping(join, input_to_query, query_base_offsets);
464 let mut conditions =
465 Self::rewrite_condition(join.conditions.clone(), join_output_to_base, total_base_len)?;
466 for (input_idx, input) in join.inputs.iter().enumerate() {
467 conditions = conditions.and(Self::shift_condition_to_join_base(
468 input.predicate.clone(),
469 *query_base_offsets.get(*input_to_query.get(input_idx)?)?,
470 input.base_len(),
471 total_base_len,
472 )?);
473 }
474 Some(conditions)
475 }
476
477 fn normalize_join_outputs(
478 join: &InnerJoinRewrite,
479 input_to_query: &[usize],
480 query_base_offsets: &[usize],
481 ) -> Option<Vec<usize>> {
482 let join_output_to_base =
483 Self::join_output_to_base_mapping(join, input_to_query, query_base_offsets);
484 join.output_indices
485 .iter()
486 .map(|idx| join_output_to_base.get(*idx).copied())
487 .collect()
488 }
489
490 fn join_output_to_base_mapping(
491 join: &InnerJoinRewrite,
492 input_to_query: &[usize],
493 query_base_offsets: &[usize],
494 ) -> Vec<usize> {
495 join.inputs
496 .iter()
497 .enumerate()
498 .flat_map(|(input_idx, input)| {
499 let offset = query_base_offsets[input_to_query[input_idx]];
500 input
501 .output_to_base
502 .iter()
503 .copied()
504 .map(move |base_idx| base_idx + offset)
505 })
506 .collect()
507 }
508
509 fn subtract_conditions(query: Condition, mv: Condition) -> Option<Condition> {
510 let mut counts: HashMap<crate::expr::ExprImpl, usize> = HashMap::new();
511 for expr in query.conjunctions {
512 *counts.entry(expr).or_default() += 1;
513 }
514 for expr in mv.conjunctions {
515 let count = counts.get_mut(&expr)?;
516 *count -= 1;
517 if *count == 0 {
518 counts.remove(&expr);
519 }
520 }
521 let residual = counts
522 .into_iter()
523 .flat_map(|(expr, count)| std::iter::repeat_n(expr, count))
524 .collect();
525 Some(Self::canonicalize_condition(Condition {
526 conjunctions: residual,
527 }))
528 }
529
530 fn rewrite_condition(
531 condition: Condition,
532 output_to_base: Vec<usize>,
533 target_size: usize,
534 ) -> Option<Condition> {
535 let mut mapping =
536 ColIndexMapping::new(output_to_base.into_iter().map(Some).collect(), target_size);
537 Some(Self::canonicalize_condition(
538 condition.rewrite_expr(&mut mapping),
539 ))
540 }
541
542 fn shift_condition_to_join_base(
543 condition: Condition,
544 offset: usize,
545 source_size: usize,
546 target_size: usize,
547 ) -> Option<Condition> {
548 let mut mapping = ColIndexMapping::new(
549 (0..source_size).map(|idx| Some(idx + offset)).collect(),
550 target_size,
551 );
552 Some(Self::canonicalize_condition(
553 condition.rewrite_expr(&mut mapping),
554 ))
555 }
556
/// Inverts `mapping` (target index -> source index) into a vector of length `source_size`
/// whose entry `s` is the first target index mapped to source `s`, or `None` if there is
/// no such target.
///
/// Returns `None` when `mapping` contains a source index `>= source_size`.
fn invert_mapping(mapping: &[usize], source_size: usize) -> Option<Vec<Option<usize>>> {
    let mut inverse: Vec<Option<usize>> = vec![None; source_size];
    for (target_idx, &source_idx) in mapping.iter().enumerate() {
        let slot = inverse.get_mut(source_idx)?;
        // Keep the first target when several targets share a source.
        if slot.is_none() {
            *slot = Some(target_idx);
        }
    }
    Some(inverse)
}
566
567 fn rewrite_condition_to_mv(
568 condition: Condition,
569 base_to_mv_output: &[Option<usize>],
570 ) -> Option<Condition> {
571 let input_refs = condition.collect_input_refs(base_to_mv_output.len());
572 if input_refs
573 .ones()
574 .any(|idx| base_to_mv_output[idx].is_none())
575 {
576 return None;
577 }
578 let target_size = base_to_mv_output
579 .iter()
580 .flatten()
581 .max()
582 .copied()
583 .map_or(0, |idx| idx + 1);
584 let mut mapping = ColIndexMapping::new(base_to_mv_output.to_vec(), target_size);
585 Some(Self::canonicalize_condition(
586 condition.rewrite_expr(&mut mapping),
587 ))
588 }
589
590 fn canonicalize_condition(condition: Condition) -> Condition {
591 let mut conjunctions = condition
592 .conjunctions
593 .into_iter()
594 .map(Self::canonicalize_expr)
595 .collect::<Vec<_>>();
596 conjunctions.sort_by_cached_key(|expr| format!("{expr:?}"));
597 Condition { conjunctions }
598 }
599
600 fn canonicalize_expr(expr: crate::expr::ExprImpl) -> crate::expr::ExprImpl {
601 if let Some((input_ref, const_expr)) = expr.as_eq_const() {
602 return FunctionCall::new_unchecked(
603 ExprType::Equal,
604 vec![input_ref.into(), const_expr],
605 expr.return_type(),
606 )
607 .into();
608 }
609 if let Some((lhs, rhs)) = expr.as_eq_cond() {
610 return FunctionCall::new_unchecked(
611 ExprType::Equal,
612 vec![lhs.into(), rhs.into()],
613 expr.return_type(),
614 )
615 .into();
616 }
617 if let ExprImpl::FunctionCall(function_call) = &expr
618 && function_call.func_type() == ExprType::IsNotDistinctFrom
619 {
620 let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
621 match (lhs, rhs) {
622 (ExprImpl::InputRef(input_ref), const_expr) if const_expr.is_const() => {
623 return FunctionCall::new_unchecked(
624 ExprType::IsNotDistinctFrom,
625 vec![(*input_ref).into(), const_expr],
626 expr.return_type(),
627 )
628 .into();
629 }
630 (const_expr, ExprImpl::InputRef(input_ref)) if const_expr.is_const() => {
631 return FunctionCall::new_unchecked(
632 ExprType::IsNotDistinctFrom,
633 vec![(*input_ref).into(), const_expr],
634 expr.return_type(),
635 )
636 .into();
637 }
638 _ => {}
639 }
640 }
641 if let Some((lhs, rhs)) = expr.as_is_not_distinct_from_cond() {
642 return FunctionCall::new_unchecked(
643 ExprType::IsNotDistinctFrom,
644 vec![lhs.into(), rhs.into()],
645 expr.return_type(),
646 )
647 .into();
648 }
649 expr
650 }
651}
652
653#[derive(Clone)]
654struct JoinLeafRewrite {
655 base_scan: LogicalPlanRef,
656 output_to_base: Vec<usize>,
657 predicate: Condition,
658}
659
660impl JoinLeafRewrite {
661 fn base_len(&self) -> usize {
662 self.base_scan.schema().len()
663 }
664}
665
666#[derive(Clone)]
667struct InnerJoinRewrite {
668 inputs: Vec<JoinLeafRewrite>,
669 conditions: Condition,
670 output_indices: Vec<usize>,
671}
672
673impl InnerJoinRewrite {
674 fn total_base_len(&self) -> usize {
675 self.inputs.iter().map(JoinLeafRewrite::base_len).sum()
676 }
677}