risingwave_frontend/optimizer/plan_node/
batch_hash_join.rs1use pretty_xmlish::{Pretty, XmlNode};
16use risingwave_pb::batch_plan::HashJoinNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType};
19
20use super::batch::prelude::*;
21use super::utils::{Distill, childless_record};
22use super::{
23    BatchPlanRef as PlanRef, EqJoinPredicate, ExprRewritable, LogicalJoin, PlanBase,
24    PlanTreeNodeBinary, ToBatchPb, ToDistributedBatch, generic,
25};
26use crate::error::Result;
27use crate::expr::{Expr, ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::utils::IndicesDisplay;
30use crate::optimizer::plan_node::{EqJoinPredicateDisplay, ToLocalBatch};
31use crate::optimizer::property::{Distribution, Order, RequiredDist};
32use crate::utils::ColIndexMappingRewriteExt;
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct BatchHashJoin {
39    pub base: PlanBase<Batch>,
40    core: generic::Join<PlanRef>,
41    eq_join_predicate: EqJoinPredicate,
44    asof_desc: Option<AsOfJoinDesc>,
46}
47
48impl BatchHashJoin {
49    pub fn new(
50        core: generic::Join<PlanRef>,
51        eq_join_predicate: EqJoinPredicate,
52        asof_desc: Option<AsOfJoinDesc>,
53    ) -> Self {
54        let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core);
55        let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
56
57        Self {
58            base,
59            core,
60            eq_join_predicate,
61            asof_desc,
62        }
63    }
64
65    pub(super) fn derive_dist(
66        left: &Distribution,
67        right: &Distribution,
68        join: &generic::Join<PlanRef>,
69    ) -> Distribution {
70        match (left, right) {
71            (Distribution::Single, Distribution::Single) => Distribution::Single,
72            (Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type {
75                JoinType::Unspecified => {
76                    unreachable!()
77                }
78                JoinType::FullOuter => Distribution::SomeShard,
79                JoinType::Inner
80                | JoinType::LeftOuter
81                | JoinType::LeftSemi
82                | JoinType::LeftAnti
83                | JoinType::AsofInner
84                | JoinType::AsofLeftOuter => {
85                    let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping());
86                    l2o.rewrite_provided_distribution(left)
87                }
88                JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
89                    let r2o = join.r2i_col_mapping().composite(&join.i2o_col_mapping());
90                    r2o.rewrite_provided_distribution(right)
91                }
92            },
93            (_, _) => unreachable!(
94                "suspicious distribution: left: {:?}, right: {:?}",
95                left, right
96            ),
97        }
98    }
99
100    pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
102        &self.eq_join_predicate
103    }
104}
105
106impl Distill for BatchHashJoin {
107    fn distill<'a>(&self) -> XmlNode<'a> {
108        let verbose = self.base.ctx().is_explain_verbose();
109        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
110        vec.push(("type", Pretty::debug(&self.core.join_type)));
111
112        let concat_schema = self.core.concat_schema();
113        vec.push((
114            "predicate",
115            Pretty::debug(&EqJoinPredicateDisplay {
116                eq_join_predicate: self.eq_join_predicate(),
117                input_schema: &concat_schema,
118            }),
119        ));
120        if verbose {
121            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
122            vec.push(("output", data));
123        }
124        childless_record("BatchHashJoin", vec)
125    }
126}
127
128impl PlanTreeNodeBinary<Batch> for BatchHashJoin {
129    fn left(&self) -> PlanRef {
130        self.core.left.clone()
131    }
132
133    fn right(&self) -> PlanRef {
134        self.core.right.clone()
135    }
136
137    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
138        let mut core = self.core.clone();
139        core.left = left;
140        core.right = right;
141        Self::new(core, self.eq_join_predicate.clone(), self.asof_desc)
142    }
143}
144
145impl_plan_tree_node_for_binary! { Batch, BatchHashJoin }
146
147impl ToDistributedBatch for BatchHashJoin {
148    fn to_distributed(&self) -> Result<PlanRef> {
149        let mut right = self.right().to_distributed_with_required(
150            &Order::any(),
151            &RequiredDist::shard_by_key(
152                self.right().schema().len(),
153                &self.eq_join_predicate().right_eq_indexes(),
154            ),
155        )?;
156        let mut left = self.left();
157
158        let r2l = self
159            .eq_join_predicate()
160            .r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
161        let l2r = self
162            .eq_join_predicate()
163            .l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
164
165        let right_dist = right.distribution();
166        match right_dist {
167            Distribution::HashShard(_) => {
168                let left_dist = r2l
169                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
170                left = left.to_distributed_with_required(&Order::any(), &left_dist)?;
171            }
172            Distribution::UpstreamHashShard(_, _) => {
173                left = left.to_distributed_with_required(
174                    &Order::any(),
175                    &RequiredDist::shard_by_key(
176                        self.left().schema().len(),
177                        &self.eq_join_predicate().left_eq_indexes(),
178                    ),
179                )?;
180                let left_dist = left.distribution();
181                match left_dist {
182                    Distribution::HashShard(_) => {
183                        let right_dist = l2r.rewrite_required_distribution(
184                            &RequiredDist::PhysicalDist(left_dist.clone()),
185                        );
186                        right = right_dist.batch_enforce_if_not_satisfies(right, &Order::any())?
187                    }
188                    Distribution::UpstreamHashShard(_, _) => {
189                        left =
190                            RequiredDist::hash_shard(&self.eq_join_predicate().left_eq_indexes())
191                                .batch_enforce_if_not_satisfies(left, &Order::any())?;
192                        right =
193                            RequiredDist::hash_shard(&self.eq_join_predicate().right_eq_indexes())
194                                .batch_enforce_if_not_satisfies(right, &Order::any())?;
195                    }
196                    _ => unreachable!(),
197                }
198            }
199            _ => unreachable!(),
200        }
201
202        Ok(self.clone_with_left_right(left, right).into())
203    }
204}
205
206impl ToBatchPb for BatchHashJoin {
207    fn to_batch_prost_body(&self) -> NodeBody {
208        NodeBody::HashJoin(HashJoinNode {
209            join_type: self.core.join_type as i32,
210            left_key: self
211                .eq_join_predicate
212                .left_eq_indexes()
213                .into_iter()
214                .map(|a| a as i32)
215                .collect(),
216            right_key: self
217                .eq_join_predicate
218                .right_eq_indexes()
219                .into_iter()
220                .map(|a| a as i32)
221                .collect(),
222            null_safe: self.eq_join_predicate.null_safes().into_iter().collect(),
223            condition: self
224                .eq_join_predicate
225                .other_cond()
226                .as_expr_unless_true()
227                .map(|x| x.to_expr_proto()),
228            output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
229            asof_desc: self.asof_desc,
230        })
231    }
232}
233
234impl ToLocalBatch for BatchHashJoin {
235    fn to_local(&self) -> Result<PlanRef> {
236        let right = RequiredDist::single()
237            .batch_enforce_if_not_satisfies(self.right().to_local()?, &Order::any())?;
238        let left = RequiredDist::single()
239            .batch_enforce_if_not_satisfies(self.left().to_local()?, &Order::any())?;
240
241        Ok(self.clone_with_left_right(left, right).into())
242    }
243}
244
245impl ExprRewritable<Batch> for BatchHashJoin {
246    fn has_rewritable_expr(&self) -> bool {
247        true
248    }
249
250    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
251        let mut core = self.core.clone();
252        core.rewrite_exprs(r);
253        let eq_join_predicate = self.eq_join_predicate.rewrite_exprs(r);
254        let desc = self.asof_desc.map(|_| {
255            LogicalJoin::get_inequality_desc_from_predicate(
256                eq_join_predicate.other_cond().clone(),
257                core.left.schema().len(),
258            )
259            .unwrap()
260        });
261        Self::new(core, eq_join_predicate, desc).into()
262    }
263}
264
265impl ExprVisitable for BatchHashJoin {
266    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
267        self.core.visit_exprs(v);
268    }
269}