risingwave_frontend/optimizer/plan_node/stream_hash_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::session_config::join_encoding_type::JoinEncodingType;
18use risingwave_common::util::functional::SameOrElseExt;
19use risingwave_pb::plan_common::JoinType;
20use risingwave_pb::stream_plan::stream_node::NodeBody;
21use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair};
22
23use super::generic::{GenericPlanNode, Join};
24use super::stream::prelude::*;
25use super::stream_join_common::StreamJoinCommon;
26use super::utils::{Distill, childless_record, plan_node_name, watermark_pretty};
27use super::{
28    ExprRewritable, PlanBase, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode,
29    StreamPlanRef as PlanRef, generic,
30};
31use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputPair};
32use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
33use crate::optimizer::plan_node::utils::IndicesDisplay;
34use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay};
35use crate::optimizer::property::{MonotonicityMap, WatermarkColumns};
36use crate::stream_fragmenter::BuildFragmentGraphState;
37
38/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table
39/// from inner (right-side) relation and probes with data from outer (left-side) relation to
40/// get output rows.
41#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub struct StreamHashJoin {
43    pub base: PlanBase<Stream>,
44    core: generic::Join<PlanRef>,
45
46    /// The join condition must be equivalent to `logical.on`, but separated into equal and
47    /// non-equal parts to facilitate execution later
48    eq_join_predicate: EqJoinPredicate,
49
50    /// `(do_state_cleaning, InequalityInputPair {key_required_larger, key_required_smaller,
51    /// delta_expression})`. View struct `InequalityInputPair` for details.
52    inequality_pairs: Vec<(bool, InequalityInputPair)>,
53
54    /// Whether can optimize for append-only stream.
55    /// It is true if input of both side is append-only
56    is_append_only: bool,
57
58    /// The conjunction index of the inequality which is used to clean left state table in
59    /// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
60    /// will always be `None`.
61    clean_left_state_conjunction_idx: Option<usize>,
62    /// The conjunction index of the inequality which is used to clean right state table in
63    /// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
64    /// will always be `None`.
65    clean_right_state_conjunction_idx: Option<usize>,
66
67    /// Determine which encoding will be used to encode join rows in operator cache.
68    join_encoding_type: JoinEncodingType,
69}
70
71impl StreamHashJoin {
72    pub fn new(core: generic::Join<PlanRef>, eq_join_predicate: EqJoinPredicate) -> Result<Self> {
73        let ctx = core.ctx();
74
75        let stream_kind = core.stream_kind()?;
76
77        let dist = StreamJoinCommon::derive_dist(
78            core.left.distribution(),
79            core.right.distribution(),
80            &core,
81        );
82
83        let mut inequality_pairs = vec![];
84        let mut clean_left_state_conjunction_idx = None;
85        let mut clean_right_state_conjunction_idx = None;
86
87        // Reorder `eq_join_predicate` by placing the watermark column at the beginning.
88        let mut reorder_idx = vec![];
89        for (i, (left_key, right_key)) in eq_join_predicate.eq_indexes().iter().enumerate() {
90            if core.left.watermark_columns().contains(*left_key)
91                && core.right.watermark_columns().contains(*right_key)
92            {
93                reorder_idx.push(i);
94            }
95        }
96        let eq_join_predicate = eq_join_predicate.reorder(&reorder_idx);
97
98        let watermark_columns = {
99            let l2i = core.l2i_col_mapping();
100            let r2i = core.r2i_col_mapping();
101
102            let mut equal_condition_clean_state = false;
103            let mut watermark_columns = WatermarkColumns::new();
104            for (left_key, right_key) in eq_join_predicate.eq_indexes() {
105                if let Some(l_wtmk_group) = core.left.watermark_columns().get_group(left_key)
106                    && let Some(r_wtmk_group) = core.right.watermark_columns().get_group(right_key)
107                {
108                    equal_condition_clean_state = true;
109                    if let Some(internal) = l2i.try_map(left_key) {
110                        watermark_columns.insert(
111                            internal,
112                            l_wtmk_group
113                                .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
114                        );
115                    }
116                    if let Some(internal) = r2i.try_map(right_key) {
117                        watermark_columns.insert(
118                            internal,
119                            l_wtmk_group
120                                .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
121                        );
122                    }
123                }
124            }
125            let (left_cols_num, original_inequality_pairs) = eq_join_predicate.inequality_pairs();
126            for (
127                conjunction_idx,
128                InequalityInputPair {
129                    key_required_larger,
130                    key_required_smaller,
131                    delta_expression,
132                },
133            ) in original_inequality_pairs
134            {
135                let both_upstream_has_watermark = if key_required_larger < key_required_smaller {
136                    core.left.watermark_columns().contains(key_required_larger)
137                        && core
138                            .right
139                            .watermark_columns()
140                            .contains(key_required_smaller - left_cols_num)
141                } else {
142                    core.left.watermark_columns().contains(key_required_smaller)
143                        && core
144                            .right
145                            .watermark_columns()
146                            .contains(key_required_larger - left_cols_num)
147                };
148                if !both_upstream_has_watermark {
149                    continue;
150                }
151
152                let (internal_col1, internal_col2, do_state_cleaning) =
153                    if key_required_larger < key_required_smaller {
154                        (
155                            l2i.try_map(key_required_larger),
156                            r2i.try_map(key_required_smaller - left_cols_num),
157                            if !equal_condition_clean_state
158                                && clean_left_state_conjunction_idx.is_none()
159                            {
160                                clean_left_state_conjunction_idx = Some(conjunction_idx);
161                                true
162                            } else {
163                                false
164                            },
165                        )
166                    } else {
167                        (
168                            r2i.try_map(key_required_larger - left_cols_num),
169                            l2i.try_map(key_required_smaller),
170                            if !equal_condition_clean_state
171                                && clean_right_state_conjunction_idx.is_none()
172                            {
173                                clean_right_state_conjunction_idx = Some(conjunction_idx);
174                                true
175                            } else {
176                                false
177                            },
178                        )
179                    };
180                let mut is_valuable_inequality = do_state_cleaning;
181                if let Some(internal) = internal_col1
182                    && !watermark_columns.contains(internal)
183                {
184                    watermark_columns.insert(internal, ctx.next_watermark_group_id());
185                    is_valuable_inequality = true;
186                }
187                if let Some(internal) = internal_col2
188                    && !watermark_columns.contains(internal)
189                {
190                    watermark_columns.insert(internal, ctx.next_watermark_group_id());
191                }
192                if is_valuable_inequality {
193                    inequality_pairs.push((
194                        do_state_cleaning,
195                        InequalityInputPair {
196                            key_required_larger,
197                            key_required_smaller,
198                            delta_expression,
199                        },
200                    ));
201                }
202            }
203            watermark_columns.map_clone(&core.i2o_col_mapping())
204        };
205
206        // TODO: derive from input
207        let base = PlanBase::new_stream_with_core(
208            &core,
209            dist,
210            stream_kind,
211            false, // TODO(rc): derive EOWC property from input
212            watermark_columns,
213            MonotonicityMap::new(), // TODO: derive monotonicity
214        );
215
216        Ok(Self {
217            base,
218            core,
219            eq_join_predicate,
220            inequality_pairs,
221            is_append_only: stream_kind.is_append_only(),
222            clean_left_state_conjunction_idx,
223            clean_right_state_conjunction_idx,
224            join_encoding_type: ctx.session_ctx().config().streaming_join_encoding(),
225        })
226    }
227
228    /// Get join type
229    pub fn join_type(&self) -> JoinType {
230        self.core.join_type
231    }
232
233    /// Get a reference to the hash join's eq join predicate.
234    pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
235        &self.eq_join_predicate
236    }
237
238    /// Convert this hash join to a delta join plan
239    pub fn into_delta_join(self) -> StreamDeltaJoin {
240        StreamDeltaJoin::new(self.core, self.eq_join_predicate).unwrap()
241    }
242
243    pub fn derive_dist_key_in_join_key(&self) -> Vec<usize> {
244        let left_dk_indices = self.left().distribution().dist_column_indices().to_vec();
245        let right_dk_indices = self.right().distribution().dist_column_indices().to_vec();
246
247        StreamJoinCommon::get_dist_key_in_join_key(
248            &left_dk_indices,
249            &right_dk_indices,
250            self.eq_join_predicate(),
251        )
252    }
253
254    pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> {
255        &self.inequality_pairs
256    }
257}
258
259impl Distill for StreamHashJoin {
260    fn distill<'a>(&self) -> XmlNode<'a> {
261        let (ljk, rjk) = self
262            .eq_join_predicate
263            .eq_indexes()
264            .first()
265            .cloned()
266            .expect("first join key");
267
268        let name = plan_node_name!("StreamHashJoin",
269            { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) },
270            { "interval", self.clean_left_state_conjunction_idx.is_some() && self.clean_right_state_conjunction_idx.is_some() },
271            { "append_only", self.is_append_only },
272        );
273        let verbose = self.base.ctx().is_explain_verbose();
274        let mut vec = Vec::with_capacity(6);
275        vec.push(("type", Pretty::debug(&self.core.join_type)));
276
277        let concat_schema = self.core.concat_schema();
278        vec.push((
279            "predicate",
280            Pretty::debug(&EqJoinPredicateDisplay {
281                eq_join_predicate: self.eq_join_predicate(),
282                input_schema: &concat_schema,
283            }),
284        ));
285
286        let get_cond = |conjunction_idx| {
287            Pretty::debug(&ExprDisplay {
288                expr: &self.eq_join_predicate().other_cond().conjunctions[conjunction_idx],
289                input_schema: &concat_schema,
290            })
291        };
292        if let Some(i) = self.clean_left_state_conjunction_idx {
293            vec.push(("conditions_to_clean_left_state_table", get_cond(i)));
294        }
295        if let Some(i) = self.clean_right_state_conjunction_idx {
296            vec.push(("conditions_to_clean_right_state_table", get_cond(i)));
297        }
298        if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) {
299            vec.push(("output_watermarks", ow));
300        }
301
302        if verbose {
303            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
304            vec.push(("output", data));
305        }
306
307        childless_record(name, vec)
308    }
309}
310
311impl PlanTreeNodeBinary<Stream> for StreamHashJoin {
312    fn left(&self) -> PlanRef {
313        self.core.left.clone()
314    }
315
316    fn right(&self) -> PlanRef {
317        self.core.right.clone()
318    }
319
320    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
321        let mut core = self.core.clone();
322        core.left = left;
323        core.right = right;
324        Self::new(core, self.eq_join_predicate.clone()).unwrap()
325    }
326}
327
328impl_plan_tree_node_for_binary! { Stream, StreamHashJoin }
329
330impl StreamNode for StreamHashJoin {
331    fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> NodeBody {
332        let left_jk_indices = self.eq_join_predicate.left_eq_indexes();
333        let right_jk_indices = self.eq_join_predicate.right_eq_indexes();
334        let left_jk_indices_prost = left_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
335        let right_jk_indices_prost = right_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
336
337        let dk_indices_in_jk = self.derive_dist_key_in_join_key();
338
339        let (left_table, left_degree_table, left_deduped_input_pk_indices) =
340            Join::infer_internal_and_degree_table_catalog(
341                self.left(),
342                left_jk_indices,
343                dk_indices_in_jk.clone(),
344            );
345        let (right_table, right_degree_table, right_deduped_input_pk_indices) =
346            Join::infer_internal_and_degree_table_catalog(
347                self.right(),
348                right_jk_indices,
349                dk_indices_in_jk,
350            );
351
352        let left_deduped_input_pk_indices = left_deduped_input_pk_indices
353            .iter()
354            .map(|idx| *idx as u32)
355            .collect_vec();
356
357        let right_deduped_input_pk_indices = right_deduped_input_pk_indices
358            .iter()
359            .map(|idx| *idx as u32)
360            .collect_vec();
361
362        let (left_table, left_degree_table) = (
363            left_table.with_id(state.gen_table_id_wrapped()),
364            left_degree_table.with_id(state.gen_table_id_wrapped()),
365        );
366        let (right_table, right_degree_table) = (
367            right_table.with_id(state.gen_table_id_wrapped()),
368            right_degree_table.with_id(state.gen_table_id_wrapped()),
369        );
370
371        let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect();
372
373        NodeBody::HashJoin(Box::new(HashJoinNode {
374            join_type: self.core.join_type as i32,
375            left_key: left_jk_indices_prost,
376            right_key: right_jk_indices_prost,
377            null_safe: null_safe_prost,
378            condition: self
379                .eq_join_predicate
380                .other_cond()
381                .as_expr_unless_true()
382                .map(|x| x.to_expr_proto()),
383            inequality_pairs: self
384                .inequality_pairs
385                .iter()
386                .map(
387                    |(
388                        do_state_clean,
389                        InequalityInputPair {
390                            key_required_larger,
391                            key_required_smaller,
392                            delta_expression,
393                        },
394                    )| {
395                        PbInequalityPair {
396                            key_required_larger: *key_required_larger as u32,
397                            key_required_smaller: *key_required_smaller as u32,
398                            clean_state: *do_state_clean,
399                            delta_expression: delta_expression.as_ref().map(
400                                |(delta_type, delta)| DeltaExpression {
401                                    delta_type: *delta_type as i32,
402                                    delta: Some(delta.to_expr_proto()),
403                                },
404                            ),
405                        }
406                    },
407                )
408                .collect_vec(),
409            left_table: Some(left_table.to_internal_table_prost()),
410            right_table: Some(right_table.to_internal_table_prost()),
411            left_degree_table: Some(left_degree_table.to_internal_table_prost()),
412            right_degree_table: Some(right_degree_table.to_internal_table_prost()),
413            left_deduped_input_pk_indices,
414            right_deduped_input_pk_indices,
415            output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
416            is_append_only: self.is_append_only,
417            join_encoding_type: self.join_encoding_type as i32,
418        }))
419    }
420}
421
422impl ExprRewritable<Stream> for StreamHashJoin {
423    fn has_rewritable_expr(&self) -> bool {
424        true
425    }
426
427    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
428        let mut core = self.core.clone();
429        core.rewrite_exprs(r);
430        Self::new(core, self.eq_join_predicate.rewrite_exprs(r))
431            .unwrap()
432            .into()
433    }
434}
435
436impl ExprVisitable for StreamHashJoin {
437    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
438        self.core.visit_exprs(v);
439    }
440}