risingwave_frontend/optimizer/plan_node/
batch_hash_join.rs1use pretty_xmlish::{Pretty, XmlNode};
16use risingwave_pb::batch_plan::HashJoinNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType};
19
20use super::batch::prelude::*;
21use super::utils::{Distill, childless_record};
22use super::{
23 EqJoinPredicate, ExprRewritable, LogicalJoin, PlanBase, PlanRef, PlanTreeNodeBinary, ToBatchPb,
24 ToDistributedBatch, generic,
25};
26use crate::error::Result;
27use crate::expr::{Expr, ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::utils::IndicesDisplay;
30use crate::optimizer::plan_node::{EqJoinPredicateDisplay, ToLocalBatch};
31use crate::optimizer::property::{Distribution, Order, RequiredDist};
32use crate::utils::ColIndexMappingRewriteExt;
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct BatchHashJoin {
39 pub base: PlanBase<Batch>,
40 core: generic::Join<PlanRef>,
41 eq_join_predicate: EqJoinPredicate,
44 asof_desc: Option<AsOfJoinDesc>,
46}
47
48impl BatchHashJoin {
49 pub fn new(
50 core: generic::Join<PlanRef>,
51 eq_join_predicate: EqJoinPredicate,
52 asof_desc: Option<AsOfJoinDesc>,
53 ) -> Self {
54 let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core);
55 let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
56
57 Self {
58 base,
59 core,
60 eq_join_predicate,
61 asof_desc,
62 }
63 }
64
65 pub(super) fn derive_dist(
66 left: &Distribution,
67 right: &Distribution,
68 join: &generic::Join<PlanRef>,
69 ) -> Distribution {
70 match (left, right) {
71 (Distribution::Single, Distribution::Single) => Distribution::Single,
72 (Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type {
75 JoinType::Unspecified => {
76 unreachable!()
77 }
78 JoinType::FullOuter => Distribution::SomeShard,
79 JoinType::Inner
80 | JoinType::LeftOuter
81 | JoinType::LeftSemi
82 | JoinType::LeftAnti
83 | JoinType::AsofInner
84 | JoinType::AsofLeftOuter => {
85 let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping());
86 l2o.rewrite_provided_distribution(left)
87 }
88 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
89 let r2o = join.r2i_col_mapping().composite(&join.i2o_col_mapping());
90 r2o.rewrite_provided_distribution(right)
91 }
92 },
93 (_, _) => unreachable!(
94 "suspicious distribution: left: {:?}, right: {:?}",
95 left, right
96 ),
97 }
98 }
99
100 pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
102 &self.eq_join_predicate
103 }
104}
105
106impl Distill for BatchHashJoin {
107 fn distill<'a>(&self) -> XmlNode<'a> {
108 let verbose = self.base.ctx().is_explain_verbose();
109 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
110 vec.push(("type", Pretty::debug(&self.core.join_type)));
111
112 let concat_schema = self.core.concat_schema();
113 vec.push((
114 "predicate",
115 Pretty::debug(&EqJoinPredicateDisplay {
116 eq_join_predicate: self.eq_join_predicate(),
117 input_schema: &concat_schema,
118 }),
119 ));
120 if verbose {
121 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
122 vec.push(("output", data));
123 }
124 childless_record("BatchHashJoin", vec)
125 }
126}
127
128impl PlanTreeNodeBinary for BatchHashJoin {
129 fn left(&self) -> PlanRef {
130 self.core.left.clone()
131 }
132
133 fn right(&self) -> PlanRef {
134 self.core.right.clone()
135 }
136
137 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
138 let mut core = self.core.clone();
139 core.left = left;
140 core.right = right;
141 Self::new(core, self.eq_join_predicate.clone(), self.asof_desc)
142 }
143}
144
145impl_plan_tree_node_for_binary! { BatchHashJoin }
146
147impl ToDistributedBatch for BatchHashJoin {
148 fn to_distributed(&self) -> Result<PlanRef> {
149 let mut right = self.right().to_distributed_with_required(
150 &Order::any(),
151 &RequiredDist::shard_by_key(
152 self.right().schema().len(),
153 &self.eq_join_predicate().right_eq_indexes(),
154 ),
155 )?;
156 let mut left = self.left();
157
158 let r2l = self
159 .eq_join_predicate()
160 .r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
161 let l2r = self
162 .eq_join_predicate()
163 .l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
164
165 let right_dist = right.distribution();
166 match right_dist {
167 Distribution::HashShard(_) => {
168 let left_dist = r2l
169 .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
170 left = left.to_distributed_with_required(&Order::any(), &left_dist)?;
171 }
172 Distribution::UpstreamHashShard(_, _) => {
173 left = left.to_distributed_with_required(
174 &Order::any(),
175 &RequiredDist::shard_by_key(
176 self.left().schema().len(),
177 &self.eq_join_predicate().left_eq_indexes(),
178 ),
179 )?;
180 let left_dist = left.distribution();
181 match left_dist {
182 Distribution::HashShard(_) => {
183 let right_dist = l2r.rewrite_required_distribution(
184 &RequiredDist::PhysicalDist(left_dist.clone()),
185 );
186 right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
187 }
188 Distribution::UpstreamHashShard(_, _) => {
189 left =
190 RequiredDist::hash_shard(&self.eq_join_predicate().left_eq_indexes())
191 .enforce_if_not_satisfies(left, &Order::any())?;
192 right =
193 RequiredDist::hash_shard(&self.eq_join_predicate().right_eq_indexes())
194 .enforce_if_not_satisfies(right, &Order::any())?;
195 }
196 _ => unreachable!(),
197 }
198 }
199 _ => unreachable!(),
200 }
201
202 Ok(self.clone_with_left_right(left, right).into())
203 }
204}
205
206impl ToBatchPb for BatchHashJoin {
207 fn to_batch_prost_body(&self) -> NodeBody {
208 NodeBody::HashJoin(HashJoinNode {
209 join_type: self.core.join_type as i32,
210 left_key: self
211 .eq_join_predicate
212 .left_eq_indexes()
213 .into_iter()
214 .map(|a| a as i32)
215 .collect(),
216 right_key: self
217 .eq_join_predicate
218 .right_eq_indexes()
219 .into_iter()
220 .map(|a| a as i32)
221 .collect(),
222 null_safe: self.eq_join_predicate.null_safes().into_iter().collect(),
223 condition: self
224 .eq_join_predicate
225 .other_cond()
226 .as_expr_unless_true()
227 .map(|x| x.to_expr_proto()),
228 output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
229 asof_desc: self.asof_desc,
230 })
231 }
232}
233
234impl ToLocalBatch for BatchHashJoin {
235 fn to_local(&self) -> Result<PlanRef> {
236 let right = RequiredDist::single()
237 .enforce_if_not_satisfies(self.right().to_local()?, &Order::any())?;
238 let left = RequiredDist::single()
239 .enforce_if_not_satisfies(self.left().to_local()?, &Order::any())?;
240
241 Ok(self.clone_with_left_right(left, right).into())
242 }
243}
244
245impl ExprRewritable for BatchHashJoin {
246 fn has_rewritable_expr(&self) -> bool {
247 true
248 }
249
250 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
251 let mut core = self.core.clone();
252 core.rewrite_exprs(r);
253 let eq_join_predicate = self.eq_join_predicate.rewrite_exprs(r);
254 let desc = self.asof_desc.map(|_| {
255 LogicalJoin::get_inequality_desc_from_predicate(
256 eq_join_predicate.other_cond().clone(),
257 core.left.schema().len(),
258 )
259 .unwrap()
260 });
261 Self::new(core, eq_join_predicate, desc).into()
262 }
263}
264
265impl ExprVisitable for BatchHashJoin {
266 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
267 self.core.visit_exprs(v);
268 }
269}