risingwave_frontend/optimizer/plan_node/
stream_hash_join.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::util::functional::SameOrElseExt;
18use risingwave_pb::plan_common::JoinType;
19use risingwave_pb::stream_plan::stream_node::NodeBody;
20use risingwave_pb::stream_plan::{
21    HashJoinNode, InequalityPairV2 as PbInequalityPairV2, InequalityType as PbInequalityType,
22    PbJoinEncodingType,
23};
24
25use super::generic::{GenericPlanNode, Join};
26use super::stream::prelude::*;
27use super::stream_join_common::StreamJoinCommon;
28use super::utils::{Distill, childless_record, plan_node_name, watermark_pretty};
29use super::{
30    ExprRewritable, PlanBase, PlanTreeNodeBinary, StreamDeltaJoin, StreamPlanRef as PlanRef,
31    TryToStreamPb, generic,
32};
33use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprType, ExprVisitor, InequalityInputPair};
34use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
35use crate::optimizer::plan_node::utils::IndicesDisplay;
36use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay};
37use crate::optimizer::property::{MonotonicityMap, WatermarkColumns};
38use crate::scheduler::SchedulerResult;
39use crate::stream_fragmenter::BuildFragmentGraphState;
40
41/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table
42/// from inner (right-side) relation and probes with data from outer (left-side) relation to
43/// get output rows.
44#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45pub struct StreamHashJoin {
46    pub base: PlanBase<Stream>,
47    core: generic::Join<PlanRef>,
48
49    /// `(clean_left_state, clean_right_state, InequalityInputPair)`.
50    /// Each entry represents an inequality condition like `left_col <op> right_col`.
51    inequality_pairs: Vec<(bool, bool, InequalityInputPair)>,
52
53    /// Whether can optimize for append-only stream.
54    /// It is true if input of both side is append-only
55    is_append_only: bool,
56
57    /// The conjunction index of the inequality which is used to clean left state table in
58    /// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
59    /// will always be `None`.
60    clean_left_state_conjunction_idx: Option<usize>,
61    /// The conjunction index of the inequality which is used to clean right state table in
62    /// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
63    /// will always be `None`.
64    clean_right_state_conjunction_idx: Option<usize>,
65}
66
67impl StreamHashJoin {
68    pub fn new(mut core: generic::Join<PlanRef>) -> Result<Self> {
69        let ctx = core.ctx();
70
71        let stream_kind = core.stream_kind()?;
72
73        // Reorder `eq_join_predicate` by placing the watermark column at the beginning.
74        let eq_join_predicate = {
75            let eq_join_predicate = core
76                .on
77                .as_eq_predicate_ref()
78                .expect("StreamHashJoin requires JoinOn::EqPredicate in core")
79                .clone();
80            let mut reorder_idx = vec![];
81            for (i, (left_key, right_key)) in eq_join_predicate.eq_indexes().iter().enumerate() {
82                if core.left.watermark_columns().contains(*left_key)
83                    && core.right.watermark_columns().contains(*right_key)
84                {
85                    reorder_idx.push(i);
86                }
87            }
88            eq_join_predicate.reorder(&reorder_idx)
89        };
90        core.on = generic::JoinOn::EqPredicate(eq_join_predicate.clone());
91
92        let dist = StreamJoinCommon::derive_dist(
93            core.left.distribution(),
94            core.right.distribution(),
95            &core,
96        );
97
98        let mut inequality_pairs = vec![];
99        let mut clean_left_state_conjunction_idx = None;
100        let mut clean_right_state_conjunction_idx = None;
101
102        let watermark_columns = {
103            let l2i = core.l2i_col_mapping();
104            let r2i = core.r2i_col_mapping();
105
106            let mut equal_condition_clean_state = false;
107            let mut watermark_columns = WatermarkColumns::new();
108            for (left_key, right_key) in eq_join_predicate.eq_indexes() {
109                if let Some(l_wtmk_group) = core.left.watermark_columns().get_group(left_key)
110                    && let Some(r_wtmk_group) = core.right.watermark_columns().get_group(right_key)
111                {
112                    equal_condition_clean_state = true;
113                    if let Some(internal) = l2i.try_map(left_key) {
114                        watermark_columns.insert(
115                            internal,
116                            l_wtmk_group
117                                .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
118                        );
119                    }
120                    if let Some(internal) = r2i.try_map(right_key) {
121                        watermark_columns.insert(
122                            internal,
123                            l_wtmk_group
124                                .same_or_else(r_wtmk_group, || ctx.next_watermark_group_id()),
125                        );
126                    }
127                }
128            }
129
130            // Process inequality pairs using the new V2 format
131            let original_inequality_pairs = eq_join_predicate.inequality_pairs_v2();
132            for (conjunction_idx, pair) in original_inequality_pairs {
133                let InequalityInputPair {
134                    left_idx,
135                    right_idx,
136                    op,
137                } = pair;
138
139                // Check if both upstream sides have watermarks on the inequality columns
140                let both_upstream_has_watermark = core.left.watermark_columns().contains(left_idx)
141                    && core.right.watermark_columns().contains(right_idx);
142                if !both_upstream_has_watermark {
143                    continue;
144                }
145
146                // Determine which side's state can be cleaned based on the operator.
147                // State cleanup applies to the side with LARGER values.
148                // For `left < right` or `left <= right`: RIGHT is larger → clean RIGHT state
149                // For `left > right` or `left >= right`: LEFT is larger → clean LEFT state
150                let left_is_larger =
151                    matches!(op, ExprType::GreaterThan | ExprType::GreaterThanOrEqual);
152
153                let (clean_left, clean_right) = if left_is_larger {
154                    // Left side is larger, we can clean left state
155                    let do_clean =
156                        !equal_condition_clean_state && clean_left_state_conjunction_idx.is_none();
157                    if do_clean {
158                        clean_left_state_conjunction_idx = Some(conjunction_idx);
159                    }
160                    (do_clean, false)
161                } else {
162                    // Right side is larger, we can clean right state
163                    let do_clean =
164                        !equal_condition_clean_state && clean_right_state_conjunction_idx.is_none();
165                    if do_clean {
166                        clean_right_state_conjunction_idx = Some(conjunction_idx);
167                    }
168                    (false, do_clean)
169                };
170
171                let mut is_valuable_inequality = clean_left || clean_right;
172
173                // Add watermark columns for the inequality.
174                // We can only yield watermark from the LARGER side downstream.
175                // For `left >= right`: left is larger, yield left watermark
176                // For `left <= right`: right is larger, yield right watermark
177                if left_is_larger {
178                    if let Some(internal) = l2i.try_map(left_idx)
179                        && !watermark_columns.contains(internal)
180                    {
181                        watermark_columns.insert(internal, ctx.next_watermark_group_id());
182                        is_valuable_inequality = true;
183                    }
184                } else if let Some(internal) = r2i.try_map(right_idx)
185                    && !watermark_columns.contains(internal)
186                {
187                    watermark_columns.insert(internal, ctx.next_watermark_group_id());
188                    is_valuable_inequality = true;
189                }
190
191                if is_valuable_inequality {
192                    inequality_pairs.push((
193                        clean_left,
194                        clean_right,
195                        InequalityInputPair::new(left_idx, right_idx, op),
196                    ));
197                }
198            }
199            watermark_columns.map_clone(&core.i2o_col_mapping())
200        };
201
202        // TODO: derive from input
203        let base = PlanBase::new_stream_with_core(
204            &core,
205            dist,
206            stream_kind,
207            false, // TODO(rc): derive EOWC property from input
208            watermark_columns,
209            MonotonicityMap::new(), // TODO: derive monotonicity
210        );
211
212        Ok(Self {
213            base,
214            core,
215            inequality_pairs,
216            is_append_only: stream_kind.is_append_only(),
217            clean_left_state_conjunction_idx,
218            clean_right_state_conjunction_idx,
219        })
220    }
221
222    /// Get join type
223    pub fn join_type(&self) -> JoinType {
224        self.core.join_type
225    }
226
227    /// Get a reference to the hash join's eq join predicate.
228    pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
229        self.core
230            .on
231            .as_eq_predicate_ref()
232            .expect("StreamHashJoin should store predicate as EqJoinPredicate")
233    }
234
235    /// Convert this hash join to a delta join plan
236    pub fn into_delta_join(self) -> StreamDeltaJoin {
237        StreamDeltaJoin::new(self.core).unwrap()
238    }
239
240    pub fn derive_dist_key_in_join_key(&self) -> Vec<usize> {
241        let left_dk_indices = self.left().distribution().dist_column_indices().to_vec();
242        let right_dk_indices = self.right().distribution().dist_column_indices().to_vec();
243
244        StreamJoinCommon::get_dist_key_in_join_key(
245            &left_dk_indices,
246            &right_dk_indices,
247            self.eq_join_predicate(),
248        )
249    }
250
251    pub fn inequality_pairs(&self) -> &Vec<(bool, bool, InequalityInputPair)> {
252        &self.inequality_pairs
253    }
254}
255
256impl Distill for StreamHashJoin {
257    fn distill<'a>(&self) -> XmlNode<'a> {
258        let (ljk, rjk) = self
259            .eq_join_predicate()
260            .eq_indexes()
261            .first()
262            .cloned()
263            .expect("first join key");
264
265        let name = plan_node_name!("StreamHashJoin",
266            { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) },
267            { "interval", self.clean_left_state_conjunction_idx.is_some() && self.clean_right_state_conjunction_idx.is_some() },
268            { "append_only", self.is_append_only },
269        );
270        let verbose = self.base.ctx().is_explain_verbose();
271        let mut vec = Vec::with_capacity(6);
272        vec.push(("type", Pretty::debug(&self.core.join_type)));
273
274        let concat_schema = self.core.concat_schema();
275        vec.push((
276            "predicate",
277            Pretty::debug(&EqJoinPredicateDisplay {
278                eq_join_predicate: self.eq_join_predicate(),
279                input_schema: &concat_schema,
280            }),
281        ));
282
283        let get_cond = |conjunction_idx| {
284            Pretty::debug(&ExprDisplay {
285                expr: &self.eq_join_predicate().other_cond().conjunctions[conjunction_idx],
286                input_schema: &concat_schema,
287            })
288        };
289        if let Some(i) = self.clean_left_state_conjunction_idx {
290            vec.push(("conditions_to_clean_left_state_table", get_cond(i)));
291        }
292        if let Some(i) = self.clean_right_state_conjunction_idx {
293            vec.push(("conditions_to_clean_right_state_table", get_cond(i)));
294        }
295        if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) {
296            vec.push(("output_watermarks", ow));
297        }
298
299        if verbose {
300            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
301            vec.push(("output", data));
302        }
303
304        childless_record(name, vec)
305    }
306}
307
308impl PlanTreeNodeBinary<Stream> for StreamHashJoin {
309    fn left(&self) -> PlanRef {
310        self.core.left.clone()
311    }
312
313    fn right(&self) -> PlanRef {
314        self.core.right.clone()
315    }
316
317    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
318        let mut core = self.core.clone();
319        core.left = left;
320        core.right = right;
321        Self::new(core).unwrap()
322    }
323}
324
325impl_plan_tree_node_for_binary! { Stream, StreamHashJoin }
326
327impl TryToStreamPb for StreamHashJoin {
328    fn try_to_stream_prost_body(
329        &self,
330        state: &mut BuildFragmentGraphState,
331    ) -> SchedulerResult<NodeBody> {
332        let left_jk_indices = self.eq_join_predicate().left_eq_indexes();
333        let right_jk_indices = self.eq_join_predicate().right_eq_indexes();
334        let left_jk_indices_prost = left_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
335        let right_jk_indices_prost = right_jk_indices.iter().map(|idx| *idx as i32).collect_vec();
336
337        let retract =
338            self.left().stream_kind().is_retract() || self.right().stream_kind().is_retract();
339
340        let dk_indices_in_jk = self.derive_dist_key_in_join_key();
341
342        let (left_table, left_degree_table, left_deduped_input_pk_indices) =
343            Join::infer_internal_and_degree_table_catalog(
344                self.left(),
345                left_jk_indices,
346                dk_indices_in_jk.clone(),
347            );
348        let (right_table, right_degree_table, right_deduped_input_pk_indices) =
349            Join::infer_internal_and_degree_table_catalog(
350                self.right(),
351                right_jk_indices,
352                dk_indices_in_jk,
353            );
354
355        let left_deduped_input_pk_indices = left_deduped_input_pk_indices
356            .iter()
357            .map(|idx| *idx as u32)
358            .collect_vec();
359
360        let right_deduped_input_pk_indices = right_deduped_input_pk_indices
361            .iter()
362            .map(|idx| *idx as u32)
363            .collect_vec();
364
365        let (left_table, left_degree_table) = (
366            left_table.with_id(state.gen_table_id_wrapped()),
367            left_degree_table.with_id(state.gen_table_id_wrapped()),
368        );
369        let (right_table, right_degree_table) = (
370            right_table.with_id(state.gen_table_id_wrapped()),
371            right_degree_table.with_id(state.gen_table_id_wrapped()),
372        );
373
374        let null_safe_prost = self.eq_join_predicate().null_safes().into_iter().collect();
375
376        let condition = self
377            .eq_join_predicate()
378            .other_cond()
379            .as_expr_unless_true()
380            .map(|expr| expr.to_expr_proto_checked_pure(retract, "JOIN condition"))
381            .transpose()?;
382
383        // Helper function to convert ExprType to PbInequalityType
384        fn expr_type_to_pb_inequality_type(op: ExprType) -> i32 {
385            match op {
386                ExprType::LessThan => PbInequalityType::LessThan as i32,
387                ExprType::LessThanOrEqual => PbInequalityType::LessThanOrEqual as i32,
388                ExprType::GreaterThan => PbInequalityType::GreaterThan as i32,
389                ExprType::GreaterThanOrEqual => PbInequalityType::GreaterThanOrEqual as i32,
390                _ => PbInequalityType::Unspecified as i32,
391            }
392        }
393
394        Ok(NodeBody::HashJoin(Box::new(HashJoinNode {
395            join_type: self.core.join_type as i32,
396            left_key: left_jk_indices_prost,
397            right_key: right_jk_indices_prost,
398            null_safe: null_safe_prost,
399            condition,
400            // Deprecated: keep empty for new plans
401            inequality_pairs: vec![],
402            // New inequality pairs with clearer semantics
403            inequality_pairs_v2: self
404                .inequality_pairs
405                .iter()
406                .map(|(clean_left, clean_right, pair)| PbInequalityPairV2 {
407                    left_idx: pair.left_idx as u32,
408                    right_idx: pair.right_idx as u32,
409                    clean_left_state: *clean_left,
410                    clean_right_state: *clean_right,
411                    op: expr_type_to_pb_inequality_type(pair.op),
412                })
413                .collect_vec(),
414            left_table: Some(left_table.to_internal_table_prost()),
415            right_table: Some(right_table.to_internal_table_prost()),
416            left_degree_table: Some(left_degree_table.to_internal_table_prost()),
417            right_degree_table: Some(right_degree_table.to_internal_table_prost()),
418            left_deduped_input_pk_indices,
419            right_deduped_input_pk_indices,
420            output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
421            is_append_only: self.is_append_only,
422            // Join encoding type should now be read from per-job config override.
423            #[allow(deprecated)]
424            join_encoding_type: PbJoinEncodingType::Unspecified as _,
425        })))
426    }
427}
428
429impl ExprRewritable<Stream> for StreamHashJoin {
430    fn has_rewritable_expr(&self) -> bool {
431        true
432    }
433
434    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
435        let mut core = self.core.clone();
436        core.rewrite_exprs(r);
437        Self::new(core).unwrap().into()
438    }
439}
440
441impl ExprVisitable for StreamHashJoin {
442    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
443        self.core.visit_exprs(v);
444    }
445}