risingwave_frontend/optimizer/plan_node/
stream_hash_join.rs1use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::util::functional::SameOrElseExt;
18use risingwave_pb::plan_common::JoinType;
19use risingwave_pb::stream_plan::stream_node::NodeBody;
20use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair};
21
22use super::generic::{GenericPlanNode, Join};
23use super::stream::prelude::*;
24use super::stream_join_common::StreamJoinCommon;
25use super::utils::{Distill, childless_record, plan_node_name, watermark_pretty};
26use super::{
27 ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode, generic,
28};
29use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputPair};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::utils::IndicesDisplay;
32use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay};
33use crate::optimizer::property::{MonotonicityMap, WatermarkColumns};
34use crate::stream_fragmenter::BuildFragmentGraphState;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct StreamHashJoin {
41 pub base: PlanBase<Stream>,
42 core: generic::Join<PlanRef>,
43
44 eq_join_predicate: EqJoinPredicate,
47
48 inequality_pairs: Vec<(bool, InequalityInputPair)>,
51
52 is_append_only: bool,
55
56 clean_left_state_conjunction_idx: Option<usize>,
60 clean_right_state_conjunction_idx: Option<usize>,
64}
65
66impl StreamHashJoin {
67 pub fn new(core: generic::Join<PlanRef>, eq_join_predicate: EqJoinPredicate) -> Self {
68 let ctx = core.ctx();
69
70 let append_only = match core.join_type {
72 JoinType::Inner => core.left.append_only() && core.right.append_only(),
73 _ => false,
74 };
75
76 let dist = StreamJoinCommon::derive_dist(
77 core.left.distribution(),
78 core.right.distribution(),
79 &core,
80 );
81
82 let mut inequality_pairs = vec![];
83 let mut clean_left_state_conjunction_idx = None;
84 let mut clean_right_state_conjunction_idx = None;
85
86 let mut reorder_idx = vec![];
88 for (i, (left_key, right_key)) in eq_join_predicate.eq_indexes().iter().enumerate() {
89 if core.left.watermark_columns().contains(*left_key)
90 && core.right.watermark_columns().contains(*right_key)
91 {
92 reorder_idx.push(i);
93 }
94 }
95 let eq_join_predicate = eq_join_predicate.reorder(&reorder_idx);
96
97 let watermark_columns = {
98 let l2i = core.l2i_col_mapping();
99 let r2i = core.r2i_col_mapping();
100
101 let mut equal_condition_clean_state = false;
102 let mut watermark_columns = WatermarkColumns::new();
103 for (left_key, right_key) in eq_join_predicate.eq_indexes() {
104 if let Some(l_wtmk_group) = core.left.watermark_columns().get_group(left_key)
105 && let Some(r_wtmk_group) = core.right.watermark_columns().get_group(right_key)
106 {
107 equal_condition_clean_state = true;
108 if let Some(internal) = l2i.try_map(left_key) {
109 watermark_columns.insert(
110 internal,
111 l_wtmk_group
112 .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
113 );
114 }
115 if let Some(internal) = r2i.try_map(right_key) {
116 watermark_columns.insert(
117 internal,
118 l_wtmk_group
119 .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
120 );
121 }
122 }
123 }
124 let (left_cols_num, original_inequality_pairs) = eq_join_predicate.inequality_pairs();
125 for (
126 conjunction_idx,
127 InequalityInputPair {
128 key_required_larger,
129 key_required_smaller,
130 delta_expression,
131 },
132 ) in original_inequality_pairs
133 {
134 let both_upstream_has_watermark = if key_required_larger < key_required_smaller {
135 core.left.watermark_columns().contains(key_required_larger)
136 && core
137 .right
138 .watermark_columns()
139 .contains(key_required_smaller - left_cols_num)
140 } else {
141 core.left.watermark_columns().contains(key_required_smaller)
142 && core
143 .right
144 .watermark_columns()
145 .contains(key_required_larger - left_cols_num)
146 };
147 if !both_upstream_has_watermark {
148 continue;
149 }
150
151 let (internal_col1, internal_col2, do_state_cleaning) =
152 if key_required_larger < key_required_smaller {
153 (
154 l2i.try_map(key_required_larger),
155 r2i.try_map(key_required_smaller - left_cols_num),
156 if !equal_condition_clean_state
157 && clean_left_state_conjunction_idx.is_none()
158 {
159 clean_left_state_conjunction_idx = Some(conjunction_idx);
160 true
161 } else {
162 false
163 },
164 )
165 } else {
166 (
167 r2i.try_map(key_required_larger - left_cols_num),
168 l2i.try_map(key_required_smaller),
169 if !equal_condition_clean_state
170 && clean_right_state_conjunction_idx.is_none()
171 {
172 clean_right_state_conjunction_idx = Some(conjunction_idx);
173 true
174 } else {
175 false
176 },
177 )
178 };
179 let mut is_valuable_inequality = do_state_cleaning;
180 if let Some(internal) = internal_col1
181 && !watermark_columns.contains(internal)
182 {
183 watermark_columns.insert(internal, ctx.next_watermark_group_id());
184 is_valuable_inequality = true;
185 }
186 if let Some(internal) = internal_col2
187 && !watermark_columns.contains(internal)
188 {
189 watermark_columns.insert(internal, ctx.next_watermark_group_id());
190 }
191 if is_valuable_inequality {
192 inequality_pairs.push((
193 do_state_cleaning,
194 InequalityInputPair {
195 key_required_larger,
196 key_required_smaller,
197 delta_expression,
198 },
199 ));
200 }
201 }
202 watermark_columns.map_clone(&core.i2o_col_mapping())
203 };
204
205 let base = PlanBase::new_stream_with_core(
207 &core,
208 dist,
209 append_only,
210 false, watermark_columns,
212 MonotonicityMap::new(), );
214
215 Self {
216 base,
217 core,
218 eq_join_predicate,
219 inequality_pairs,
220 is_append_only: append_only,
221 clean_left_state_conjunction_idx,
222 clean_right_state_conjunction_idx,
223 }
224 }
225
226 pub fn join_type(&self) -> JoinType {
228 self.core.join_type
229 }
230
231 pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
233 &self.eq_join_predicate
234 }
235
236 pub fn into_delta_join(self) -> StreamDeltaJoin {
238 StreamDeltaJoin::new(self.core, self.eq_join_predicate)
239 }
240
241 pub fn derive_dist_key_in_join_key(&self) -> Vec<usize> {
242 let left_dk_indices = self.left().distribution().dist_column_indices().to_vec();
243 let right_dk_indices = self.right().distribution().dist_column_indices().to_vec();
244
245 StreamJoinCommon::get_dist_key_in_join_key(
246 &left_dk_indices,
247 &right_dk_indices,
248 self.eq_join_predicate(),
249 )
250 }
251
252 pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> {
253 &self.inequality_pairs
254 }
255}
256
257impl Distill for StreamHashJoin {
258 fn distill<'a>(&self) -> XmlNode<'a> {
259 let (ljk, rjk) = self
260 .eq_join_predicate
261 .eq_indexes()
262 .first()
263 .cloned()
264 .expect("first join key");
265
266 let name = plan_node_name!("StreamHashJoin",
267 { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) },
268 { "interval", self.clean_left_state_conjunction_idx.is_some() && self.clean_right_state_conjunction_idx.is_some() },
269 { "append_only", self.is_append_only },
270 );
271 let verbose = self.base.ctx().is_explain_verbose();
272 let mut vec = Vec::with_capacity(6);
273 vec.push(("type", Pretty::debug(&self.core.join_type)));
274
275 let concat_schema = self.core.concat_schema();
276 vec.push((
277 "predicate",
278 Pretty::debug(&EqJoinPredicateDisplay {
279 eq_join_predicate: self.eq_join_predicate(),
280 input_schema: &concat_schema,
281 }),
282 ));
283
284 let get_cond = |conjunction_idx| {
285 Pretty::debug(&ExprDisplay {
286 expr: &self.eq_join_predicate().other_cond().conjunctions[conjunction_idx],
287 input_schema: &concat_schema,
288 })
289 };
290 if let Some(i) = self.clean_left_state_conjunction_idx {
291 vec.push(("conditions_to_clean_left_state_table", get_cond(i)));
292 }
293 if let Some(i) = self.clean_right_state_conjunction_idx {
294 vec.push(("conditions_to_clean_right_state_table", get_cond(i)));
295 }
296 if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) {
297 vec.push(("output_watermarks", ow));
298 }
299
300 if verbose {
301 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
302 vec.push(("output", data));
303 }
304
305 childless_record(name, vec)
306 }
307}
308
309impl PlanTreeNodeBinary for StreamHashJoin {
310 fn left(&self) -> PlanRef {
311 self.core.left.clone()
312 }
313
314 fn right(&self) -> PlanRef {
315 self.core.right.clone()
316 }
317
318 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
319 let mut core = self.core.clone();
320 core.left = left;
321 core.right = right;
322 Self::new(core, self.eq_join_predicate.clone())
323 }
324}
325
// Presumably generates the generic plan-tree-node trait impl for `StreamHashJoin` on top of
// its `PlanTreeNodeBinary` impl — confirm against the macro definition.
impl_plan_tree_node_for_binary! { StreamHashJoin }
327
328impl StreamNode for StreamHashJoin {
329 fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> NodeBody {
330 let left_jk_indices = self.eq_join_predicate.left_eq_indexes();
331 let right_jk_indices = self.eq_join_predicate.right_eq_indexes();
332 let left_jk_indices_prost = left_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
333 let right_jk_indices_prost = right_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
334
335 let dk_indices_in_jk = self.derive_dist_key_in_join_key();
336
337 let (left_table, left_degree_table, left_deduped_input_pk_indices) =
338 Join::infer_internal_and_degree_table_catalog(
339 self.left().plan_base(),
340 left_jk_indices,
341 dk_indices_in_jk.clone(),
342 );
343 let (right_table, right_degree_table, right_deduped_input_pk_indices) =
344 Join::infer_internal_and_degree_table_catalog(
345 self.right().plan_base(),
346 right_jk_indices,
347 dk_indices_in_jk,
348 );
349
350 let left_deduped_input_pk_indices = left_deduped_input_pk_indices
351 .iter()
352 .map(|idx| *idx as u32)
353 .collect_vec();
354
355 let right_deduped_input_pk_indices = right_deduped_input_pk_indices
356 .iter()
357 .map(|idx| *idx as u32)
358 .collect_vec();
359
360 let (left_table, left_degree_table) = (
361 left_table.with_id(state.gen_table_id_wrapped()),
362 left_degree_table.with_id(state.gen_table_id_wrapped()),
363 );
364 let (right_table, right_degree_table) = (
365 right_table.with_id(state.gen_table_id_wrapped()),
366 right_degree_table.with_id(state.gen_table_id_wrapped()),
367 );
368
369 let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect();
370
371 NodeBody::HashJoin(Box::new(HashJoinNode {
372 join_type: self.core.join_type as i32,
373 left_key: left_jk_indices_prost,
374 right_key: right_jk_indices_prost,
375 null_safe: null_safe_prost,
376 condition: self
377 .eq_join_predicate
378 .other_cond()
379 .as_expr_unless_true()
380 .map(|x| x.to_expr_proto()),
381 inequality_pairs: self
382 .inequality_pairs
383 .iter()
384 .map(
385 |(
386 do_state_clean,
387 InequalityInputPair {
388 key_required_larger,
389 key_required_smaller,
390 delta_expression,
391 },
392 )| {
393 PbInequalityPair {
394 key_required_larger: *key_required_larger as u32,
395 key_required_smaller: *key_required_smaller as u32,
396 clean_state: *do_state_clean,
397 delta_expression: delta_expression.as_ref().map(
398 |(delta_type, delta)| DeltaExpression {
399 delta_type: *delta_type as i32,
400 delta: Some(delta.to_expr_proto()),
401 },
402 ),
403 }
404 },
405 )
406 .collect_vec(),
407 left_table: Some(left_table.to_internal_table_prost()),
408 right_table: Some(right_table.to_internal_table_prost()),
409 left_degree_table: Some(left_degree_table.to_internal_table_prost()),
410 right_degree_table: Some(right_degree_table.to_internal_table_prost()),
411 left_deduped_input_pk_indices,
412 right_deduped_input_pk_indices,
413 output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
414 is_append_only: self.is_append_only,
415 }))
416 }
417}
418
419impl ExprRewritable for StreamHashJoin {
420 fn has_rewritable_expr(&self) -> bool {
421 true
422 }
423
424 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
425 let mut core = self.core.clone();
426 core.rewrite_exprs(r);
427 Self::new(core, self.eq_join_predicate.rewrite_exprs(r)).into()
428 }
429}
430
431impl ExprVisitable for StreamHashJoin {
432 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
433 self.core.visit_exprs(v);
434 }
435}