// risingwave_frontend/optimizer/plan_rewriter/share_source_rewriter.rs
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use crate::catalog::SourceId;
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::{
LogicalShare, LogicalSource, PlanNodeId, PlanTreeNode, StreamShare,
};
use crate::optimizer::plan_visitor::{DefaultBehavior, DefaultValue};
use crate::optimizer::{PlanRewriter, PlanVisitor};
use crate::PlanRef;
/// Plan rewriter that makes every `LogicalSource` scan of the same catalog source reuse a
/// single `LogicalShare` node, when that source appears more than once in the plan.
///
/// Constructed and driven by [`ShareSourceRewriter::share_source`].
#[derive(Debug, Clone, Default)]
pub struct ShareSourceRewriter {
    /// Ids of catalog-backed sources counted more than once in the plan being rewritten;
    /// only these sources get wrapped in a `LogicalShare`.
    share_ids: HashSet<SourceId>,
    /// For each id in `share_ids`, the `LogicalShare` created at its first occurrence;
    /// every later occurrence is replaced by this same node.
    share_source: HashMap<SourceId, PlanRef>,
    /// Plan node id of an existing `LogicalShare` -> its rewritten counterpart, so each
    /// original share is rewritten only once and stays shared in the output plan.
    share_map: HashMap<PlanNodeId, PlanRef>,
}
/// Plan visitor that counts how many `LogicalSource` nodes refer to each catalog source id.
#[derive(Debug, Clone, Default)]
struct SourceCounter {
    /// Source id -> number of visited `LogicalSource` nodes carrying that id.
    /// Sources without a catalog are not counted.
    source_counter: HashMap<SourceId, usize>,
}
impl ShareSourceRewriter {
    /// Rewrites `plan` so that all scans of a catalog source that occurs more than once
    /// share one `LogicalShare` node.
    ///
    /// This runs in two passes:
    /// 1. A [`SourceCounter`] pass tallies `LogicalSource` occurrences per source id.
    /// 2. A rewrite pass replaces the scans of every id whose tally exceeds one.
    ///
    /// Sources that occur once, or that have no catalog, are left as they are.
    pub fn share_source(plan: PlanRef) -> PlanRef {
        let mut counter = SourceCounter::default();
        counter.visit(plan.clone());

        // A share node only pays off when the same source is scanned at least twice.
        let share_ids = counter
            .source_counter
            .into_iter()
            .filter_map(|(id, occurrences)| (occurrences > 1).then_some(id))
            .collect();

        let mut rewriter = ShareSourceRewriter {
            share_ids,
            ..Default::default()
        };
        rewriter.rewrite(plan)
    }
}
impl PlanRewriter for ShareSourceRewriter {
    /// Replaces a catalog-backed source whose id is in `share_ids` with a shared node.
    ///
    /// The first occurrence creates a `LogicalShare` over a clone of the source and caches
    /// it; later occurrences return that cached node. Sources without a catalog, or whose
    /// id is not marked for sharing, are returned unchanged.
    fn rewrite_logical_source(&mut self, source: &LogicalSource) -> PlanRef {
        let Some(catalog) = &source.core.catalog else {
            return source.clone().into();
        };
        let source_id = catalog.id;

        if self.share_ids.contains(&source_id) {
            // Create the share lazily on the first hit, then hand out the same node.
            self.share_source
                .entry(source_id)
                .or_insert_with(|| LogicalShare::create(source.clone().into()))
                .clone()
        } else {
            source.clone().into()
        }
    }

    /// Rewrites the inputs of an existing `LogicalShare` once per original share node
    /// (keyed by its plan node id). Every later visit returns the cached result, so the
    /// node is still shared in the rewritten plan.
    fn rewrite_logical_share(&mut self, share: &LogicalShare) -> PlanRef {
        let share_id = share.id();
        if let Some(rewritten) = self.share_map.get(&share_id) {
            return rewritten.clone();
        }

        let rewritten_inputs = share
            .inputs()
            .into_iter()
            .map(|child| self.rewrite(child))
            .collect_vec();
        let rewritten = share.clone_with_inputs(&rewritten_inputs);
        self.share_map.insert(share_id, rewritten.clone());
        rewritten
    }

    /// This rewriter only expects logical plans; a stream share node is never valid here.
    fn rewrite_stream_share(&mut self, _share: &StreamShare) -> PlanRef {
        unreachable!()
    }
}
impl PlanVisitor for SourceCounter {
    type Result = ();

    type DefaultBehavior = impl DefaultBehavior<Self::Result>;

    /// Counting happens only through side effects on `source_counter`, so the visitor's
    /// result is `()` and `DefaultValue` is used as the default behavior.
    fn default_behavior() -> Self::DefaultBehavior {
        DefaultValue
    }

    /// Increments the occurrence count of a catalog-backed source; sources without a
    /// catalog are skipped.
    ///
    /// NOTE(review): if the default traversal reaches a subtree through several parents
    /// (e.g. below an existing `LogicalShare`), its sources are counted once per path.
    /// Confirm this is intended.
    fn visit_logical_source(&mut self, source: &LogicalSource) {
        let Some(catalog) = &source.core.catalog else {
            return;
        };
        *self.source_counter.entry(catalog.id).or_insert(0) += 1;
    }
}