risingwave_frontend/optimizer/plan_node/
batch_hash_join.rs1use pretty_xmlish::{Pretty, XmlNode};
16use risingwave_pb::batch_plan::HashJoinNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType};
19
20use super::batch::prelude::*;
21use super::utils::{Distill, childless_record};
22use super::{
23 BatchPlanRef as PlanRef, EqJoinPredicate, ExprRewritable, LogicalJoin, PlanBase,
24 PlanTreeNodeBinary, ToBatchPb, ToDistributedBatch, generic,
25};
26use crate::error::Result;
27use crate::expr::{Expr, ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::utils::IndicesDisplay;
30use crate::optimizer::plan_node::{EqJoinPredicateDisplay, ToLocalBatch};
31use crate::optimizer::property::{Distribution, Order, RequiredDist};
32use crate::utils::ColIndexMappingRewriteExt;
33
/// Batch physical plan node for an equi-join, serialized as a `HashJoinNode`.
///
/// The join condition is stored in `core.on` and must be an equi-join predicate
/// (`JoinOn::EqPredicate`); [`BatchHashJoin::new`] panics otherwise.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchHashJoin {
    pub base: PlanBase<Batch>,
    /// Generic join description: both inputs, join type, `on` condition and output indices.
    core: generic::Join<PlanRef>,
    /// Optional AS OF join descriptor. It is forwarded verbatim into the protobuf node and
    /// re-derived from the non-equi condition whenever expressions are rewritten.
    asof_desc: Option<AsOfJoinDesc>,
}
44
45impl BatchHashJoin {
46 pub fn new(core: generic::Join<PlanRef>, asof_desc: Option<AsOfJoinDesc>) -> Self {
47 core.on
48 .as_eq_predicate_ref()
49 .expect("BatchHashJoin requires JoinOn::EqPredicate in core");
50 let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core);
51 let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
52
53 Self {
54 base,
55 core,
56 asof_desc,
57 }
58 }
59
60 pub(super) fn derive_dist(
61 left: &Distribution,
62 right: &Distribution,
63 join: &generic::Join<PlanRef>,
64 ) -> Distribution {
65 match (left, right) {
66 (Distribution::Single, Distribution::Single) => Distribution::Single,
67 (Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type {
70 JoinType::Unspecified => {
71 unreachable!()
72 }
73 JoinType::FullOuter => Distribution::SomeShard,
74 JoinType::Inner
75 | JoinType::LeftOuter
76 | JoinType::LeftSemi
77 | JoinType::LeftAnti
78 | JoinType::AsofInner
79 | JoinType::AsofLeftOuter => {
80 let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping());
81 l2o.rewrite_provided_distribution(left)
82 }
83 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
84 let r2o = join.r2i_col_mapping().composite(&join.i2o_col_mapping());
85 r2o.rewrite_provided_distribution(right)
86 }
87 },
88 (_, _) => unreachable!(
89 "suspicious distribution: left: {:?}, right: {:?}",
90 left, right
91 ),
92 }
93 }
94
95 pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
97 self.core
98 .on
99 .as_eq_predicate_ref()
100 .expect("BatchHashJoin should store predicate as EqJoinPredicate")
101 }
102}
103
104impl Distill for BatchHashJoin {
105 fn distill<'a>(&self) -> XmlNode<'a> {
106 let verbose = self.base.ctx().is_explain_verbose();
107 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
108 vec.push(("type", Pretty::debug(&self.core.join_type)));
109
110 let concat_schema = self.core.concat_schema();
111 vec.push((
112 "predicate",
113 Pretty::debug(&EqJoinPredicateDisplay {
114 eq_join_predicate: self.eq_join_predicate(),
115 input_schema: &concat_schema,
116 }),
117 ));
118 if verbose {
119 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
120 vec.push(("output", data));
121 }
122 childless_record("BatchHashJoin", vec)
123 }
124}
125
126impl PlanTreeNodeBinary<Batch> for BatchHashJoin {
127 fn left(&self) -> PlanRef {
128 self.core.left.clone()
129 }
130
131 fn right(&self) -> PlanRef {
132 self.core.right.clone()
133 }
134
135 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
136 let mut core = self.core.clone();
137 core.left = left;
138 core.right = right;
139 Self::new(core, self.asof_desc)
140 }
141}
142
// Generates the generic plan-tree-node impl for `BatchHashJoin` as a two-input (binary) node.
// NOTE(review): presumably built on top of the `PlanTreeNodeBinary` impl in this file — confirm
// in the macro definition.
impl_plan_tree_node_for_binary! { Batch, BatchHashJoin }
144
145impl ToDistributedBatch for BatchHashJoin {
146 fn to_distributed(&self) -> Result<PlanRef> {
147 let mut right = self.right().to_distributed_with_required(
148 &Order::any(),
149 &RequiredDist::shard_by_key(
150 self.right().schema().len(),
151 &self.eq_join_predicate().right_eq_indexes(),
152 ),
153 )?;
154 let mut left = self.left();
155
156 let r2l = self
157 .eq_join_predicate()
158 .r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
159 let l2r = self
160 .eq_join_predicate()
161 .l2r_eq_columns_mapping(left.schema().len(), right.schema().len());
162
163 let right_dist = right.distribution();
164 match right_dist {
165 Distribution::HashShard(_) => {
166 let left_dist = r2l
167 .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
168 left = left.to_distributed_with_required(&Order::any(), &left_dist)?;
169 }
170 Distribution::UpstreamHashShard(_, _) => {
171 left = left.to_distributed_with_required(
172 &Order::any(),
173 &RequiredDist::shard_by_key(
174 self.left().schema().len(),
175 &self.eq_join_predicate().left_eq_indexes(),
176 ),
177 )?;
178 let left_dist = left.distribution();
179 match left_dist {
180 Distribution::HashShard(_) => {
181 let right_dist = l2r.rewrite_required_distribution(
182 &RequiredDist::PhysicalDist(left_dist.clone()),
183 );
184 right = right_dist.batch_enforce_if_not_satisfies(right, &Order::any())?
185 }
186 Distribution::UpstreamHashShard(_, _) => {
187 left =
188 RequiredDist::hash_shard(&self.eq_join_predicate().left_eq_indexes())
189 .batch_enforce_if_not_satisfies(left, &Order::any())?;
190 right =
191 RequiredDist::hash_shard(&self.eq_join_predicate().right_eq_indexes())
192 .batch_enforce_if_not_satisfies(right, &Order::any())?;
193 }
194 _ => unreachable!(),
195 }
196 }
197 _ => unreachable!(),
198 }
199
200 Ok(self.clone_with_left_right(left, right).into())
201 }
202}
203
204impl ToBatchPb for BatchHashJoin {
205 fn to_batch_prost_body(&self) -> NodeBody {
206 let eq_join_predicate = self.eq_join_predicate();
207 NodeBody::HashJoin(HashJoinNode {
208 join_type: self.core.join_type as i32,
209 left_key: self
210 .eq_join_predicate()
211 .left_eq_indexes()
212 .into_iter()
213 .map(|a| a as i32)
214 .collect(),
215 right_key: self
216 .eq_join_predicate()
217 .right_eq_indexes()
218 .into_iter()
219 .map(|a| a as i32)
220 .collect(),
221 null_safe: eq_join_predicate.null_safes().into_iter().collect(),
222 condition: self
223 .eq_join_predicate()
224 .other_cond()
225 .as_expr_unless_true()
226 .map(|x| x.to_expr_proto()),
227 output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
228 asof_desc: self.asof_desc,
229 })
230 }
231}
232
233impl ToLocalBatch for BatchHashJoin {
234 fn to_local(&self) -> Result<PlanRef> {
235 let right = RequiredDist::single()
236 .batch_enforce_if_not_satisfies(self.right().to_local()?, &Order::any())?;
237 let left = RequiredDist::single()
238 .batch_enforce_if_not_satisfies(self.left().to_local()?, &Order::any())?;
239
240 Ok(self.clone_with_left_right(left, right).into())
241 }
242}
243
244impl ExprRewritable<Batch> for BatchHashJoin {
245 fn has_rewritable_expr(&self) -> bool {
246 true
247 }
248
249 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
250 let mut core = self.core.clone();
251 core.rewrite_exprs(r);
252 let eq_join_predicate = core
253 .on
254 .as_eq_predicate_ref()
255 .expect("BatchHashJoin should store predicate as EqJoinPredicate");
256 let desc = self.asof_desc.map(|_| {
257 LogicalJoin::get_inequality_desc_from_predicate(
258 eq_join_predicate.other_cond().clone(),
259 core.left.schema().len(),
260 )
261 .unwrap()
262 });
263 Self::new(core, desc).into()
264 }
265}
266
267impl ExprVisitable for BatchHashJoin {
268 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
269 self.core.visit_exprs(v);
270 }
271}