risingwave_frontend/optimizer/plan_node/
stream_hash_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::util::functional::SameOrElseExt;
18use risingwave_pb::plan_common::JoinType;
19use risingwave_pb::stream_plan::stream_node::NodeBody;
20use risingwave_pb::stream_plan::{
21    DeltaExpression, HashJoinNode, PbInequalityPair, PbJoinEncodingType,
22};
23
24use super::generic::{GenericPlanNode, Join};
25use super::stream::prelude::*;
26use super::stream_join_common::StreamJoinCommon;
27use super::utils::{Distill, childless_record, plan_node_name, watermark_pretty};
28use super::{
29    ExprRewritable, PlanBase, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode,
30    StreamPlanRef as PlanRef, generic,
31};
32use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputPair};
33use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
34use crate::optimizer::plan_node::utils::IndicesDisplay;
35use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay};
36use crate::optimizer::property::{MonotonicityMap, WatermarkColumns};
37use crate::stream_fragmenter::BuildFragmentGraphState;
38
39/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table
40/// from inner (right-side) relation and probes with data from outer (left-side) relation to
41/// get output rows.
42#[derive(Debug, Clone, PartialEq, Eq, Hash)]
43pub struct StreamHashJoin {
44    pub base: PlanBase<Stream>,
45    core: generic::Join<PlanRef>,
46
47    /// The join condition must be equivalent to `logical.on`, but separated into equal and
48    /// non-equal parts to facilitate execution later
49    eq_join_predicate: EqJoinPredicate,
50
51    /// `(do_state_cleaning, InequalityInputPair {key_required_larger, key_required_smaller,
52    /// delta_expression})`. View struct `InequalityInputPair` for details.
53    inequality_pairs: Vec<(bool, InequalityInputPair)>,
54
55    /// Whether can optimize for append-only stream.
56    /// It is true if input of both side is append-only
57    is_append_only: bool,
58
59    /// The conjunction index of the inequality which is used to clean left state table in
60    /// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
61    /// will always be `None`.
62    clean_left_state_conjunction_idx: Option<usize>,
63    /// The conjunction index of the inequality which is used to clean right state table in
64    /// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
65    /// will always be `None`.
66    clean_right_state_conjunction_idx: Option<usize>,
67}
68
69impl StreamHashJoin {
70    pub fn new(core: generic::Join<PlanRef>, eq_join_predicate: EqJoinPredicate) -> Result<Self> {
71        let ctx = core.ctx();
72
73        let stream_kind = core.stream_kind()?;
74
75        let dist = StreamJoinCommon::derive_dist(
76            core.left.distribution(),
77            core.right.distribution(),
78            &core,
79        );
80
81        let mut inequality_pairs = vec![];
82        let mut clean_left_state_conjunction_idx = None;
83        let mut clean_right_state_conjunction_idx = None;
84
85        // Reorder `eq_join_predicate` by placing the watermark column at the beginning.
86        let mut reorder_idx = vec![];
87        for (i, (left_key, right_key)) in eq_join_predicate.eq_indexes().iter().enumerate() {
88            if core.left.watermark_columns().contains(*left_key)
89                && core.right.watermark_columns().contains(*right_key)
90            {
91                reorder_idx.push(i);
92            }
93        }
94        let eq_join_predicate = eq_join_predicate.reorder(&reorder_idx);
95
96        let watermark_columns = {
97            let l2i = core.l2i_col_mapping();
98            let r2i = core.r2i_col_mapping();
99
100            let mut equal_condition_clean_state = false;
101            let mut watermark_columns = WatermarkColumns::new();
102            for (left_key, right_key) in eq_join_predicate.eq_indexes() {
103                if let Some(l_wtmk_group) = core.left.watermark_columns().get_group(left_key)
104                    && let Some(r_wtmk_group) = core.right.watermark_columns().get_group(right_key)
105                {
106                    equal_condition_clean_state = true;
107                    if let Some(internal) = l2i.try_map(left_key) {
108                        watermark_columns.insert(
109                            internal,
110                            l_wtmk_group
111                                .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
112                        );
113                    }
114                    if let Some(internal) = r2i.try_map(right_key) {
115                        watermark_columns.insert(
116                            internal,
117                            l_wtmk_group
118                                .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
119                        );
120                    }
121                }
122            }
123            let (left_cols_num, original_inequality_pairs) = eq_join_predicate.inequality_pairs();
124            for (
125                conjunction_idx,
126                InequalityInputPair {
127                    key_required_larger,
128                    key_required_smaller,
129                    delta_expression,
130                },
131            ) in original_inequality_pairs
132            {
133                let both_upstream_has_watermark = if key_required_larger < key_required_smaller {
134                    core.left.watermark_columns().contains(key_required_larger)
135                        && core
136                            .right
137                            .watermark_columns()
138                            .contains(key_required_smaller - left_cols_num)
139                } else {
140                    core.left.watermark_columns().contains(key_required_smaller)
141                        && core
142                            .right
143                            .watermark_columns()
144                            .contains(key_required_larger - left_cols_num)
145                };
146                if !both_upstream_has_watermark {
147                    continue;
148                }
149
150                let (internal_col1, internal_col2, do_state_cleaning) =
151                    if key_required_larger < key_required_smaller {
152                        (
153                            l2i.try_map(key_required_larger),
154                            r2i.try_map(key_required_smaller - left_cols_num),
155                            if !equal_condition_clean_state
156                                && clean_left_state_conjunction_idx.is_none()
157                            {
158                                clean_left_state_conjunction_idx = Some(conjunction_idx);
159                                true
160                            } else {
161                                false
162                            },
163                        )
164                    } else {
165                        (
166                            r2i.try_map(key_required_larger - left_cols_num),
167                            l2i.try_map(key_required_smaller),
168                            if !equal_condition_clean_state
169                                && clean_right_state_conjunction_idx.is_none()
170                            {
171                                clean_right_state_conjunction_idx = Some(conjunction_idx);
172                                true
173                            } else {
174                                false
175                            },
176                        )
177                    };
178                let mut is_valuable_inequality = do_state_cleaning;
179                if let Some(internal) = internal_col1
180                    && !watermark_columns.contains(internal)
181                {
182                    watermark_columns.insert(internal, ctx.next_watermark_group_id());
183                    is_valuable_inequality = true;
184                }
185                if let Some(internal) = internal_col2
186                    && !watermark_columns.contains(internal)
187                {
188                    watermark_columns.insert(internal, ctx.next_watermark_group_id());
189                }
190                if is_valuable_inequality {
191                    inequality_pairs.push((
192                        do_state_cleaning,
193                        InequalityInputPair {
194                            key_required_larger,
195                            key_required_smaller,
196                            delta_expression,
197                        },
198                    ));
199                }
200            }
201            watermark_columns.map_clone(&core.i2o_col_mapping())
202        };
203
204        // TODO: derive from input
205        let base = PlanBase::new_stream_with_core(
206            &core,
207            dist,
208            stream_kind,
209            false, // TODO(rc): derive EOWC property from input
210            watermark_columns,
211            MonotonicityMap::new(), // TODO: derive monotonicity
212        );
213
214        Ok(Self {
215            base,
216            core,
217            eq_join_predicate,
218            inequality_pairs,
219            is_append_only: stream_kind.is_append_only(),
220            clean_left_state_conjunction_idx,
221            clean_right_state_conjunction_idx,
222        })
223    }
224
225    /// Get join type
226    pub fn join_type(&self) -> JoinType {
227        self.core.join_type
228    }
229
230    /// Get a reference to the hash join's eq join predicate.
231    pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
232        &self.eq_join_predicate
233    }
234
235    /// Convert this hash join to a delta join plan
236    pub fn into_delta_join(self) -> StreamDeltaJoin {
237        StreamDeltaJoin::new(self.core, self.eq_join_predicate).unwrap()
238    }
239
240    pub fn derive_dist_key_in_join_key(&self) -> Vec<usize> {
241        let left_dk_indices = self.left().distribution().dist_column_indices().to_vec();
242        let right_dk_indices = self.right().distribution().dist_column_indices().to_vec();
243
244        StreamJoinCommon::get_dist_key_in_join_key(
245            &left_dk_indices,
246            &right_dk_indices,
247            self.eq_join_predicate(),
248        )
249    }
250
251    pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> {
252        &self.inequality_pairs
253    }
254}
255
256impl Distill for StreamHashJoin {
257    fn distill<'a>(&self) -> XmlNode<'a> {
258        let (ljk, rjk) = self
259            .eq_join_predicate
260            .eq_indexes()
261            .first()
262            .cloned()
263            .expect("first join key");
264
265        let name = plan_node_name!("StreamHashJoin",
266            { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) },
267            { "interval", self.clean_left_state_conjunction_idx.is_some() && self.clean_right_state_conjunction_idx.is_some() },
268            { "append_only", self.is_append_only },
269        );
270        let verbose = self.base.ctx().is_explain_verbose();
271        let mut vec = Vec::with_capacity(6);
272        vec.push(("type", Pretty::debug(&self.core.join_type)));
273
274        let concat_schema = self.core.concat_schema();
275        vec.push((
276            "predicate",
277            Pretty::debug(&EqJoinPredicateDisplay {
278                eq_join_predicate: self.eq_join_predicate(),
279                input_schema: &concat_schema,
280            }),
281        ));
282
283        let get_cond = |conjunction_idx| {
284            Pretty::debug(&ExprDisplay {
285                expr: &self.eq_join_predicate().other_cond().conjunctions[conjunction_idx],
286                input_schema: &concat_schema,
287            })
288        };
289        if let Some(i) = self.clean_left_state_conjunction_idx {
290            vec.push(("conditions_to_clean_left_state_table", get_cond(i)));
291        }
292        if let Some(i) = self.clean_right_state_conjunction_idx {
293            vec.push(("conditions_to_clean_right_state_table", get_cond(i)));
294        }
295        if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) {
296            vec.push(("output_watermarks", ow));
297        }
298
299        if verbose {
300            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
301            vec.push(("output", data));
302        }
303
304        childless_record(name, vec)
305    }
306}
307
308impl PlanTreeNodeBinary<Stream> for StreamHashJoin {
309    fn left(&self) -> PlanRef {
310        self.core.left.clone()
311    }
312
313    fn right(&self) -> PlanRef {
314        self.core.right.clone()
315    }
316
317    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
318        let mut core = self.core.clone();
319        core.left = left;
320        core.right = right;
321        Self::new(core, self.eq_join_predicate.clone()).unwrap()
322    }
323}
324
325impl_plan_tree_node_for_binary! { Stream, StreamHashJoin }
326
327impl StreamNode for StreamHashJoin {
328    fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> NodeBody {
329        let left_jk_indices = self.eq_join_predicate.left_eq_indexes();
330        let right_jk_indices = self.eq_join_predicate.right_eq_indexes();
331        let left_jk_indices_prost = left_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
332        let right_jk_indices_prost = right_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
333
334        let dk_indices_in_jk = self.derive_dist_key_in_join_key();
335
336        let (left_table, left_degree_table, left_deduped_input_pk_indices) =
337            Join::infer_internal_and_degree_table_catalog(
338                self.left(),
339                left_jk_indices,
340                dk_indices_in_jk.clone(),
341            );
342        let (right_table, right_degree_table, right_deduped_input_pk_indices) =
343            Join::infer_internal_and_degree_table_catalog(
344                self.right(),
345                right_jk_indices,
346                dk_indices_in_jk,
347            );
348
349        let left_deduped_input_pk_indices = left_deduped_input_pk_indices
350            .iter()
351            .map(|idx| *idx as u32)
352            .collect_vec();
353
354        let right_deduped_input_pk_indices = right_deduped_input_pk_indices
355            .iter()
356            .map(|idx| *idx as u32)
357            .collect_vec();
358
359        let (left_table, left_degree_table) = (
360            left_table.with_id(state.gen_table_id_wrapped()),
361            left_degree_table.with_id(state.gen_table_id_wrapped()),
362        );
363        let (right_table, right_degree_table) = (
364            right_table.with_id(state.gen_table_id_wrapped()),
365            right_degree_table.with_id(state.gen_table_id_wrapped()),
366        );
367
368        let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect();
369
370        NodeBody::HashJoin(Box::new(HashJoinNode {
371            join_type: self.core.join_type as i32,
372            left_key: left_jk_indices_prost,
373            right_key: right_jk_indices_prost,
374            null_safe: null_safe_prost,
375            condition: self
376                .eq_join_predicate
377                .other_cond()
378                .as_expr_unless_true()
379                .map(|x| x.to_expr_proto()),
380            inequality_pairs: self
381                .inequality_pairs
382                .iter()
383                .map(
384                    |(
385                        do_state_clean,
386                        InequalityInputPair {
387                            key_required_larger,
388                            key_required_smaller,
389                            delta_expression,
390                        },
391                    )| {
392                        PbInequalityPair {
393                            key_required_larger: *key_required_larger as u32,
394                            key_required_smaller: *key_required_smaller as u32,
395                            clean_state: *do_state_clean,
396                            delta_expression: delta_expression.as_ref().map(
397                                |(delta_type, delta)| DeltaExpression {
398                                    delta_type: *delta_type as i32,
399                                    delta: Some(delta.to_expr_proto()),
400                                },
401                            ),
402                        }
403                    },
404                )
405                .collect_vec(),
406            left_table: Some(left_table.to_internal_table_prost()),
407            right_table: Some(right_table.to_internal_table_prost()),
408            left_degree_table: Some(left_degree_table.to_internal_table_prost()),
409            right_degree_table: Some(right_degree_table.to_internal_table_prost()),
410            left_deduped_input_pk_indices,
411            right_deduped_input_pk_indices,
412            output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
413            is_append_only: self.is_append_only,
414            // Join encoding type should now be read from per-job config override.
415            #[allow(deprecated)]
416            join_encoding_type: PbJoinEncodingType::Unspecified as _,
417        }))
418    }
419}
420
421impl ExprRewritable<Stream> for StreamHashJoin {
422    fn has_rewritable_expr(&self) -> bool {
423        true
424    }
425
426    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
427        let mut core = self.core.clone();
428        core.rewrite_exprs(r);
429        Self::new(core, self.eq_join_predicate.rewrite_exprs(r))
430            .unwrap()
431            .into()
432    }
433}
434
435impl ExprVisitable for StreamHashJoin {
436    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
437        self.core.visit_exprs(v);
438    }
439}