// risingwave_frontend/optimizer/plan_node/predicate_pushdown.rs
use std::collections::HashMap;
use paste::paste;
use super::*;
use crate::optimizer::plan_visitor::ShareParentCounter;
use crate::optimizer::PlanVisitor;
use crate::{for_batch_plan_nodes, for_stream_plan_nodes};
/// Rewrites a plan node by pushing a filter predicate down into it.
///
/// Physical (batch / stream) plan nodes receive an implementation that panics
/// (see the `ban_predicate_pushdown` macro in this file), so in practice this
/// is meant to be called on logical plan nodes only.
pub trait PredicatePushdown {
    /// Pushes `predicate` into this node, returning the rewritten plan.
    ///
    /// `ctx` is shared mutable state for the whole pushdown pass (for example,
    /// predicates collected for shared subplans).
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef;
}
/// Emits a panicking `PredicatePushdown` impl for each `{ convention, name }`
/// pair it is given. The implementing type's name is formed by concatenating
/// the two identifiers with `paste!` (e.g. `{ Batch, Filter }` -> `BatchFilter`).
macro_rules! ban_predicate_pushdown {
    ($({ $conv:ident, $node:ident }),*) => {
        paste! {
            $(
                impl PredicatePushdown for [<$conv $node>] {
                    fn predicate_pushdown(
                        &self,
                        _predicate: Condition,
                        _ctx: &mut PredicatePushdownContext,
                    ) -> PlanRef {
                        unreachable!("predicate pushdown is only allowed on logical plan")
                    }
                }
            )*
        }
    };
}

// Every batch and every stream plan node gets the panicking implementation.
for_batch_plan_nodes! { ban_predicate_pushdown }
for_stream_plan_nodes! { ban_predicate_pushdown }
/// Splits a pushdown at a unary node.
///
/// `pushed_predicate` is pushed into the node's single input. The node is then
/// rebuilt on top of the rewritten input, and the result is passed together
/// with `filter_predicate` (the part that stays above `node`) to
/// `LogicalFilter::create`.
#[inline]
pub fn gen_filter_and_pushdown<T>(
    node: &T,
    filter_predicate: Condition,
    pushed_predicate: Condition,
    ctx: &mut PredicatePushdownContext,
) -> PlanRef
where
    T: PlanTreeNodeUnary + PlanNode,
{
    let rewritten_input = node
        .input()
        .predicate_pushdown(pushed_predicate, ctx);
    LogicalFilter::create(
        node.clone_with_input(rewritten_input).into(),
        filter_predicate,
    )
}
/// Mutable state carried through one predicate-pushdown pass.
///
/// It records, per plan node id, the predicates pushed toward a shared
/// subplan. It also holds a `ShareParentCounter` computed once from the plan
/// root, which reports how many parents each `LogicalShare` has.
#[derive(Debug, Clone)]
pub struct PredicatePushdownContext {
    /// Predicates collected so far, keyed by the (share) plan node id.
    share_predicate_map: HashMap<PlanNodeId, Vec<Condition>>,
    /// Parent counts for share nodes, filled in by visiting the root in `new`.
    share_parent_counter: ShareParentCounter,
}

impl PredicatePushdownContext {
    /// Creates a context for the plan rooted at `root`.
    ///
    /// The plan is visited once with a `ShareParentCounter`, so
    /// `get_parent_num` can answer queries later.
    pub fn new(root: PlanRef) -> Self {
        let mut share_parent_counter = ShareParentCounter::default();
        share_parent_counter.visit(root);
        Self {
            share_predicate_map: Default::default(),
            share_parent_counter,
        }
    }

    /// Returns the number of parents the `ShareParentCounter` recorded for `share`.
    pub fn get_parent_num(&self, share: &LogicalShare) -> usize {
        self.share_parent_counter.get_parent_num(share)
    }

    /// Appends `predicate` to the list stored for `plan_node_id`.
    ///
    /// Returns the number of predicates now stored for that id, including this one.
    ///
    /// NOTE(review): presumably callers compare this count against
    /// `get_parent_num` to detect when every parent of a share has pushed its
    /// predicate. Confirm against the `LogicalShare` implementation.
    pub fn add_predicate(&mut self, plan_node_id: PlanNodeId, predicate: Condition) -> usize {
        // A single entry lookup. `or_default` lets `predicate` be moved in
        // directly, avoiding the clone the `and_modify`/`or_insert_with` pair
        // required.
        let predicates = self.share_predicate_map.entry(plan_node_id).or_default();
        predicates.push(predicate);
        predicates.len()
    }

    /// Removes and returns every predicate recorded for `plan_node_id`.
    ///
    /// Returns `None` if nothing has been added for it since the last take.
    pub fn take_predicate(&mut self, plan_node_id: PlanNodeId) -> Option<Vec<Condition>> {
        self.share_predicate_map.remove(&plan_node_id)
    }
}