risingwave_frontend/optimizer/plan_node/stream_join_common.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use itertools::Itertools;
16use risingwave_common::util::iter_util::ZipEqFast;
17use risingwave_pb::plan_common::JoinType;
18
19use super::{EqJoinPredicate, generic};
20use crate::PlanRef;
21use crate::optimizer::property::Distribution;
22use crate::utils::ColIndexMappingRewriteExt;
23
24pub struct StreamJoinCommon;
25
26impl StreamJoinCommon {
27    pub(super) fn get_dist_key_in_join_key(
28        left_dk_indices: &[usize],
29        right_dk_indices: &[usize],
30        eq_join_predicate: &EqJoinPredicate,
31    ) -> Vec<usize> {
32        let left_jk_indices = eq_join_predicate.left_eq_indexes();
33        let right_jk_indices = &eq_join_predicate.right_eq_indexes();
34        assert_eq!(left_jk_indices.len(), right_jk_indices.len());
35        let mut dk_indices_in_jk = vec![];
36        for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) {
37            for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) {
38                if right_jk_indices[dk_idx_in_jk] == *r_dk_idx {
39                    dk_indices_in_jk.push(dk_idx_in_jk);
40                    break;
41                }
42            }
43        }
44        assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len());
45        dk_indices_in_jk
46    }
47
48    pub(super) fn derive_dist(
49        left: &Distribution,
50        right: &Distribution,
51        logical: &generic::Join<PlanRef>,
52    ) -> Distribution {
53        match (left, right) {
54            (Distribution::Single, Distribution::Single) => Distribution::Single,
55            (Distribution::HashShard(_), Distribution::HashShard(_)) => {
56                // we can not derive the hash distribution from the side where outer join can
57                // generate a NULL row
58                match logical.join_type {
59                    JoinType::Unspecified => {
60                        unreachable!()
61                    }
62                    JoinType::FullOuter => Distribution::SomeShard,
63                    JoinType::Inner
64                    | JoinType::LeftOuter
65                    | JoinType::LeftSemi
66                    | JoinType::LeftAnti
67                    | JoinType::AsofInner
68                    | JoinType::AsofLeftOuter => {
69                        let l2o = logical
70                            .l2i_col_mapping()
71                            .composite(&logical.i2o_col_mapping());
72                        l2o.rewrite_provided_distribution(left)
73                    }
74                    JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
75                        let r2o = logical
76                            .r2i_col_mapping()
77                            .composite(&logical.i2o_col_mapping());
78                        r2o.rewrite_provided_distribution(right)
79                    }
80                }
81            }
82            (_, _) => unreachable!(
83                "suspicious distribution: left: {:?}, right: {:?}",
84                left, right
85            ),
86        }
87    }
88}