// risingwave_frontend/optimizer/plan_node/batch_hash_join.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use pretty_xmlish::{Pretty, XmlNode};
16use risingwave_pb::batch_plan::HashJoinNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType};
19
20use super::batch::prelude::*;
21use super::utils::{Distill, childless_record};
22use super::{
23    BatchPlanRef as PlanRef, EqJoinPredicate, ExprRewritable, LogicalJoin, PlanBase,
24    PlanTreeNodeBinary, ToBatchPb, ToDistributedBatch, generic,
25};
26use crate::error::Result;
27use crate::expr::{Expr, ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::utils::IndicesDisplay;
30use crate::optimizer::plan_node::{EqJoinPredicateDisplay, ToLocalBatch};
31use crate::optimizer::property::{Distribution, Order, RequiredDist};
32use crate::utils::ColIndexMappingRewriteExt;
33
34/// `BatchHashJoin` implements [`super::LogicalJoin`] with hash table. It builds a hash table
35/// from inner (right-side) relation and then probes with data from outer (left-side) relation to
36/// get output rows.
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct BatchHashJoin {
39    pub base: PlanBase<Batch>,
40    core: generic::Join<PlanRef>,
41    /// `AsOf` desc
42    asof_desc: Option<AsOfJoinDesc>,
43}
44
45impl BatchHashJoin {
46    pub fn new(core: generic::Join<PlanRef>, asof_desc: Option<AsOfJoinDesc>) -> Self {
47        core.on
48            .as_eq_predicate_ref()
49            .expect("BatchHashJoin requires JoinOn::EqPredicate in core");
50        let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core);
51        let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
52
53        Self {
54            base,
55            core,
56            asof_desc,
57        }
58    }
59
60    pub(super) fn derive_dist(
61        left: &Distribution,
62        right: &Distribution,
63        join: &generic::Join<PlanRef>,
64    ) -> Distribution {
65        match (left, right) {
66            (Distribution::Single, Distribution::Single) => Distribution::Single,
67            // we can not derive the hash distribution from the side where outer join can generate a
68            // NULL row
69            (Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type {
70                JoinType::Unspecified => {
71                    unreachable!()
72                }
73                JoinType::FullOuter => Distribution::SomeShard,
74                JoinType::Inner
75                | JoinType::LeftOuter
76                | JoinType::LeftSemi
77                | JoinType::LeftAnti
78                | JoinType::AsofInner
79                | JoinType::AsofLeftOuter => {
80                    let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping());
81                    l2o.rewrite_provided_distribution(left)
82                }
83                JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
84                    let r2o = join.r2i_col_mapping().composite(&join.i2o_col_mapping());
85                    r2o.rewrite_provided_distribution(right)
86                }
87            },
88            (_, _) => unreachable!(
89                "suspicious distribution: left: {:?}, right: {:?}",
90                left, right
91            ),
92        }
93    }
94
95    /// Get a reference to the batch hash join's eq join predicate.
96    pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
97        self.core
98            .on
99            .as_eq_predicate_ref()
100            .expect("BatchHashJoin should store predicate as EqJoinPredicate")
101    }
102}
103
104impl Distill for BatchHashJoin {
105    fn distill<'a>(&self) -> XmlNode<'a> {
106        let verbose = self.base.ctx().is_explain_verbose();
107        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
108        vec.push(("type", Pretty::debug(&self.core.join_type)));
109
110        let concat_schema = self.core.concat_schema();
111        vec.push((
112            "predicate",
113            Pretty::debug(&EqJoinPredicateDisplay {
114                eq_join_predicate: self.eq_join_predicate(),
115                input_schema: &concat_schema,
116            }),
117        ));
118        if verbose {
119            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
120            vec.push(("output", data));
121        }
122        childless_record("BatchHashJoin", vec)
123    }
124}
125
126impl PlanTreeNodeBinary<Batch> for BatchHashJoin {
127    fn left(&self) -> PlanRef {
128        self.core.left.clone()
129    }
130
131    fn right(&self) -> PlanRef {
132        self.core.right.clone()
133    }
134
135    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
136        let mut core = self.core.clone();
137        core.left = left;
138        core.right = right;
139        Self::new(core, self.asof_desc)
140    }
141}
142
143impl_plan_tree_node_for_binary! { Batch, BatchHashJoin }
144
145impl ToDistributedBatch for BatchHashJoin {
146    fn to_distributed(&self) -> Result<PlanRef> {
147        let mut right = self.right().to_distributed_with_required(
148            &Order::any(),
149            &RequiredDist::shard_by_key(
150                self.right().schema().len(),
151                &self.eq_join_predicate().right_eq_indexes(),
152            ),
153        )?;
154        let mut left = self.left();
155
156        let r2l = self
157            .eq_join_predicate()
158            .r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
159        let l2r = self
160            .eq_join_predicate()
161            .l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
162
163        let right_dist = right.distribution();
164        match right_dist {
165            Distribution::HashShard(_) => {
166                let left_dist = r2l
167                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
168                left = left.to_distributed_with_required(&Order::any(), &left_dist)?;
169            }
170            Distribution::UpstreamHashShard(_, _) => {
171                left = left.to_distributed_with_required(
172                    &Order::any(),
173                    &RequiredDist::shard_by_key(
174                        self.left().schema().len(),
175                        &self.eq_join_predicate().left_eq_indexes(),
176                    ),
177                )?;
178                let left_dist = left.distribution();
179                match left_dist {
180                    Distribution::HashShard(_) => {
181                        let right_dist = l2r.rewrite_required_distribution(
182                            &RequiredDist::PhysicalDist(left_dist.clone()),
183                        );
184                        right = right_dist.batch_enforce_if_not_satisfies(right, &Order::any())?
185                    }
186                    Distribution::UpstreamHashShard(_, _) => {
187                        left =
188                            RequiredDist::hash_shard(&self.eq_join_predicate().left_eq_indexes())
189                                .batch_enforce_if_not_satisfies(left, &Order::any())?;
190                        right =
191                            RequiredDist::hash_shard(&self.eq_join_predicate().right_eq_indexes())
192                                .batch_enforce_if_not_satisfies(right, &Order::any())?;
193                    }
194                    _ => unreachable!(),
195                }
196            }
197            _ => unreachable!(),
198        }
199
200        Ok(self.clone_with_left_right(left, right).into())
201    }
202}
203
204impl ToBatchPb for BatchHashJoin {
205    fn to_batch_prost_body(&self) -> NodeBody {
206        let eq_join_predicate = self.eq_join_predicate();
207        NodeBody::HashJoin(HashJoinNode {
208            join_type: self.core.join_type as i32,
209            left_key: self
210                .eq_join_predicate()
211                .left_eq_indexes()
212                .into_iter()
213                .map(|a| a as i32)
214                .collect(),
215            right_key: self
216                .eq_join_predicate()
217                .right_eq_indexes()
218                .into_iter()
219                .map(|a| a as i32)
220                .collect(),
221            null_safe: eq_join_predicate.null_safes().into_iter().collect(),
222            condition: self
223                .eq_join_predicate()
224                .other_cond()
225                .as_expr_unless_true()
226                .map(|x| x.to_expr_proto()),
227            output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
228            asof_desc: self.asof_desc,
229        })
230    }
231}
232
233impl ToLocalBatch for BatchHashJoin {
234    fn to_local(&self) -> Result<PlanRef> {
235        let right = RequiredDist::single()
236            .batch_enforce_if_not_satisfies(self.right().to_local()?, &Order::any())?;
237        let left = RequiredDist::single()
238            .batch_enforce_if_not_satisfies(self.left().to_local()?, &Order::any())?;
239
240        Ok(self.clone_with_left_right(left, right).into())
241    }
242}
243
244impl ExprRewritable<Batch> for BatchHashJoin {
245    fn has_rewritable_expr(&self) -> bool {
246        true
247    }
248
249    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
250        let mut core = self.core.clone();
251        core.rewrite_exprs(r);
252        let eq_join_predicate = core
253            .on
254            .as_eq_predicate_ref()
255            .expect("BatchHashJoin should store predicate as EqJoinPredicate");
256        let desc = self.asof_desc.map(|_| {
257            LogicalJoin::get_inequality_desc_from_predicate(
258                eq_join_predicate.other_cond().clone(),
259                core.left.schema().len(),
260            )
261            .unwrap()
262        });
263        Self::new(core, desc).into()
264    }
265}
266
267impl ExprVisitable for BatchHashJoin {
268    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
269        self.core.visit_exprs(v);
270    }
271}