risingwave_frontend/optimizer/plan_node/generic/
join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::{EitherOrBoth, Itertools};
16use risingwave_common::catalog::{Field, Schema};
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::OrderType;
19use risingwave_pb::plan_common::JoinType;
20
21use super::{EqJoinPredicate, GenericPlanNode, GenericPlanRef};
22use crate::TableCatalog;
23use crate::expr::{ExprRewriter, ExprVisitor};
24use crate::optimizer::optimizer_context::OptimizerContextRef;
25use crate::optimizer::plan_node::stream;
26use crate::optimizer::plan_node::utils::TableCatalogBuilder;
27use crate::optimizer::property::FunctionalDependencySet;
28use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
29
/// [`Join`] combines two relations according to some condition.
///
/// Each output row has fields from the left and right inputs. The set of output rows is a subset
/// of the cartesian product of the two inputs; precisely which subset depends on the join
/// condition. In addition, the output columns are a subset of the columns of the left and
/// right inputs, dependent on the output indices provided. A repeat output index is illegal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Join<PlanRef> {
    /// Left input of the join.
    pub left: PlanRef,
    /// Right input of the join.
    pub right: PlanRef,
    /// The join condition.
    pub on: Condition,
    /// Which kind of join is performed; it also determines how many internal columns the join
    /// has (see `full_out_col_num`).
    pub join_type: JoinType,
    /// Internal column indices that are actually emitted, in output order. Must not contain
    /// duplicates (checked by a debug assertion in `Join::new`).
    pub output_indices: Vec<usize>,
}
44
/// Returns `true` when at least one value occurs more than once in `slice`.
///
/// Each element is compared against the elements that follow it, so no extra allocation is
/// needed; empty and single-element slices never contain a repeat.
pub(crate) fn has_repeated_element(slice: &[usize]) -> bool {
    slice
        .iter()
        .enumerate()
        .any(|(pos, value)| slice[pos + 1..].contains(value))
}
48
49impl<PlanRef: GenericPlanRef> Join<PlanRef> {
50    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
51        self.on = self.on.clone().rewrite_expr(r);
52    }
53
54    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
55        self.on.visit_expr(v);
56    }
57
58    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
59        let left_len = self.left.schema().len();
60        let right_len = self.right.schema().len();
61        let eq_predicate = EqJoinPredicate::create(left_len, right_len, self.on.clone());
62        eq_predicate.eq_indexes()
63    }
64
65    pub fn new(
66        left: PlanRef,
67        right: PlanRef,
68        on: Condition,
69        join_type: JoinType,
70        output_indices: Vec<usize>,
71    ) -> Self {
72        // We cannot deal with repeated output indices in join
73        debug_assert!(!has_repeated_element(&output_indices));
74        Self {
75            left,
76            right,
77            on,
78            join_type,
79            output_indices,
80        }
81    }
82}
83
84impl<I: stream::StreamPlanRef> Join<I> {
85    /// Return stream hash join internal table catalog and degree table catalog.
86    pub fn infer_internal_and_degree_table_catalog(
87        input: I,
88        join_key_indices: Vec<usize>,
89        dk_indices_in_jk: Vec<usize>,
90    ) -> (TableCatalog, TableCatalog, Vec<usize>) {
91        let schema = input.schema();
92
93        let internal_table_dist_keys = dk_indices_in_jk
94            .iter()
95            .map(|idx| join_key_indices[*idx])
96            .collect_vec();
97
98        let degree_table_dist_keys = dk_indices_in_jk.clone();
99
100        // The pk of hash join internal and degree table should be join_key + input_pk.
101        let join_key_len = join_key_indices.len();
102        let mut pk_indices = join_key_indices;
103
104        // dedup the pk in dist key..
105        let mut deduped_input_pk_indices = vec![];
106        for input_pk_idx in input.stream_key().unwrap() {
107            if !pk_indices.contains(input_pk_idx)
108                && !deduped_input_pk_indices.contains(input_pk_idx)
109            {
110                deduped_input_pk_indices.push(*input_pk_idx);
111            }
112        }
113
114        pk_indices.extend(deduped_input_pk_indices.clone());
115
116        // Build internal table
117        let mut internal_table_catalog_builder = TableCatalogBuilder::default();
118        let internal_columns_fields = schema.fields().to_vec();
119
120        internal_columns_fields.iter().for_each(|field| {
121            internal_table_catalog_builder.add_column(field);
122        });
123        pk_indices.iter().for_each(|idx| {
124            internal_table_catalog_builder.add_order_column(*idx, OrderType::ascending())
125        });
126
127        // Build degree table.
128        let mut degree_table_catalog_builder = TableCatalogBuilder::default();
129
130        let degree_column_field = Field::with_name(DataType::Int64, "_degree");
131
132        pk_indices.iter().enumerate().for_each(|(order_idx, idx)| {
133            degree_table_catalog_builder.add_column(&internal_columns_fields[*idx]);
134            degree_table_catalog_builder.add_order_column(order_idx, OrderType::ascending());
135        });
136        degree_table_catalog_builder.add_column(&degree_column_field);
137        degree_table_catalog_builder
138            .set_value_indices(vec![degree_table_catalog_builder.columns().len() - 1]);
139
140        internal_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk.clone());
141        degree_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk);
142
143        (
144            internal_table_catalog_builder.build(internal_table_dist_keys, join_key_len),
145            degree_table_catalog_builder.build(degree_table_dist_keys, join_key_len),
146            deduped_input_pk_indices,
147        )
148    }
149}
150
151impl<PlanRef: GenericPlanRef> GenericPlanNode for Join<PlanRef> {
152    fn schema(&self) -> Schema {
153        let left_schema = self.left.schema();
154        let right_schema = self.right.schema();
155        let i2l = self.i2l_col_mapping();
156        let i2r = self.i2r_col_mapping();
157        let fields = self
158            .output_indices
159            .iter()
160            .map(|&i| match (i2l.try_map(i), i2r.try_map(i)) {
161                (Some(l_i), None) => left_schema.fields()[l_i].clone(),
162                (None, Some(r_i)) => right_schema.fields()[r_i].clone(),
163                _ => panic!(
164                    "left len {}, right len {}, i {}, lmap {:?}, rmap {:?}",
165                    left_schema.len(),
166                    right_schema.len(),
167                    i,
168                    i2l,
169                    i2r
170                ),
171            })
172            .collect();
173        Schema { fields }
174    }
175
176    fn stream_key(&self) -> Option<Vec<usize>> {
177        let eq_indexes = self.eq_indexes();
178        let left_pk = self.left.stream_key()?;
179        let right_pk = self.right.stream_key()?;
180        let l2i = self.l2i_col_mapping();
181        let r2i = self.r2i_col_mapping();
182        let full_out_col_num = self.internal_column_num();
183        let i2o = ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num);
184
185        let mut pk_indices = left_pk
186            .iter()
187            .map(|index| l2i.try_map(*index))
188            .chain(right_pk.iter().map(|index| r2i.try_map(*index)))
189            .flatten()
190            .map(|index| i2o.try_map(index))
191            .collect::<Option<Vec<_>>>()?;
192
193        // NOTE(st1page): add join keys in the pk_indices a work around before we really have stream
194        // key.
195        let l2i = self.l2i_col_mapping();
196        let r2i = self.r2i_col_mapping();
197        let full_out_col_num = self.internal_column_num();
198        let i2o = ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num);
199
200        let either_or_both = self.add_which_join_key_to_pk();
201
202        for (lk, rk) in eq_indexes {
203            match either_or_both {
204                EitherOrBoth::Left(_) => {
205                    // Remove right-side join-key column it from pk_indices.
206                    // This may happen when right-side join-key is included in right-side PK.
207                    // e.g. select a, b where a.bid = b.id
208                    // Here the pk_indices should be [a.id, a.bid] instead of [a.id, b.id, a.bid],
209                    // because b.id = a.bid, so either of them would be enough.
210                    if let Some(rk) = r2i.try_map(rk) {
211                        if let Some(out_k) = i2o.try_map(rk) {
212                            pk_indices.retain(|&x| x != out_k);
213                        }
214                    }
215                    // Add left-side join-key column in pk_indices
216                    if let Some(lk) = l2i.try_map(lk) {
217                        let out_k = i2o.try_map(lk)?;
218                        if !pk_indices.contains(&out_k) {
219                            pk_indices.push(out_k);
220                        }
221                    }
222                }
223                EitherOrBoth::Right(_) => {
224                    // Remove left-side join-key column it from pk_indices
225                    // See the example above
226                    if let Some(lk) = l2i.try_map(lk) {
227                        if let Some(out_k) = i2o.try_map(lk) {
228                            pk_indices.retain(|&x| x != out_k);
229                        }
230                    }
231                    // Add right-side join-key column in pk_indices
232                    if let Some(rk) = r2i.try_map(rk) {
233                        let out_k = i2o.try_map(rk)?;
234                        if !pk_indices.contains(&out_k) {
235                            pk_indices.push(out_k);
236                        }
237                    }
238                }
239                EitherOrBoth::Both(_, _) => {
240                    if let Some(lk) = l2i.try_map(lk) {
241                        let out_k = i2o.try_map(lk)?;
242                        if !pk_indices.contains(&out_k) {
243                            pk_indices.push(out_k);
244                        }
245                    }
246                    if let Some(rk) = r2i.try_map(rk) {
247                        let out_k = i2o.try_map(rk)?;
248                        if !pk_indices.contains(&out_k) {
249                            pk_indices.push(out_k);
250                        }
251                    }
252                }
253            };
254        }
255        Some(pk_indices)
256    }
257
258    fn ctx(&self) -> OptimizerContextRef {
259        self.left.ctx()
260    }
261
262    fn functional_dependency(&self) -> FunctionalDependencySet {
263        let left_len = self.left.schema().len();
264        let right_len = self.right.schema().len();
265        let left_fd_set = self.left.functional_dependency().clone();
266        let right_fd_set = self.right.functional_dependency().clone();
267
268        let full_out_col_num = self.internal_column_num();
269
270        let get_new_left_fd_set = |left_fd_set: FunctionalDependencySet| {
271            ColIndexMapping::with_shift_offset(left_len, 0)
272                .composite(&ColIndexMapping::identity(full_out_col_num))
273                .rewrite_functional_dependency_set(left_fd_set)
274        };
275        let get_new_right_fd_set = |right_fd_set: FunctionalDependencySet| {
276            ColIndexMapping::with_shift_offset(right_len, left_len.try_into().unwrap())
277                .rewrite_functional_dependency_set(right_fd_set)
278        };
279        let fd_set: FunctionalDependencySet = match self.join_type {
280            JoinType::Inner | JoinType::AsofInner => {
281                let mut fd_set = FunctionalDependencySet::new(full_out_col_num);
282                for i in &self.on.conjunctions {
283                    if let Some((col, _)) = i.as_eq_const() {
284                        fd_set.add_constant_columns(&[col.index()])
285                    } else if let Some((left, right)) = i.as_eq_cond() {
286                        fd_set.add_functional_dependency_by_column_indices(
287                            &[left.index()],
288                            &[right.index()],
289                        );
290                        fd_set.add_functional_dependency_by_column_indices(
291                            &[right.index()],
292                            &[left.index()],
293                        );
294                    }
295                }
296                get_new_left_fd_set(left_fd_set)
297                    .into_dependencies()
298                    .into_iter()
299                    .chain(get_new_right_fd_set(right_fd_set).into_dependencies())
300                    .for_each(|fd| fd_set.add_functional_dependency(fd));
301                fd_set
302            }
303            JoinType::LeftOuter | JoinType::AsofLeftOuter => get_new_left_fd_set(left_fd_set),
304            JoinType::RightOuter => get_new_right_fd_set(right_fd_set),
305            JoinType::FullOuter => FunctionalDependencySet::new(full_out_col_num),
306            JoinType::LeftSemi | JoinType::LeftAnti => left_fd_set,
307            JoinType::RightSemi | JoinType::RightAnti => right_fd_set,
308            JoinType::Unspecified => unreachable!(),
309        };
310        ColIndexMapping::with_remaining_columns(&self.output_indices, full_out_col_num)
311            .rewrite_functional_dependency_set(fd_set)
312    }
313}
314
315impl<PlanRef> Join<PlanRef> {
316    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
317        (
318            self.left,
319            self.right,
320            self.on,
321            self.join_type,
322            self.output_indices,
323        )
324    }
325
326    pub fn full_out_col_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize {
327        match join_type {
328            JoinType::Inner
329            | JoinType::LeftOuter
330            | JoinType::RightOuter
331            | JoinType::FullOuter
332            | JoinType::AsofInner
333            | JoinType::AsofLeftOuter => left_len + right_len,
334            JoinType::LeftSemi | JoinType::LeftAnti => left_len,
335            JoinType::RightSemi | JoinType::RightAnti => right_len,
336            JoinType::Unspecified => unreachable!(),
337        }
338    }
339}
340
341impl<PlanRef: GenericPlanRef> Join<PlanRef> {
342    pub fn with_full_output(
343        left: PlanRef,
344        right: PlanRef,
345        join_type: JoinType,
346        on: Condition,
347    ) -> Self {
348        let out_column_num =
349            Self::full_out_col_num(left.schema().len(), right.schema().len(), join_type);
350        Self {
351            left,
352            right,
353            join_type,
354            on,
355            output_indices: (0..out_column_num).collect(),
356        }
357    }
358
359    pub fn internal_column_num(&self) -> usize {
360        Self::full_out_col_num(
361            self.left.schema().len(),
362            self.right.schema().len(),
363            self.join_type,
364        )
365    }
366
367    pub fn is_full_out(&self) -> bool {
368        self.output_indices.len() == self.internal_column_num()
369    }
370
371    /// Get the Mapping of columnIndex from internal column index to left column index.
372    pub fn i2l_col_mapping(&self) -> ColIndexMapping {
373        let left_len = self.left.schema().len();
374        let right_len = self.right.schema().len();
375
376        match self.join_type {
377            JoinType::Inner
378            | JoinType::LeftOuter
379            | JoinType::RightOuter
380            | JoinType::FullOuter
381            | JoinType::AsofInner
382            | JoinType::AsofLeftOuter => {
383                ColIndexMapping::identity_or_none(left_len + right_len, left_len)
384            }
385
386            JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::identity(left_len),
387            JoinType::RightSemi | JoinType::RightAnti => {
388                ColIndexMapping::empty(right_len, left_len)
389            }
390            JoinType::Unspecified => unreachable!(),
391        }
392    }
393
394    /// Get the Mapping of columnIndex from internal column index to right column index.
395    pub fn i2r_col_mapping(&self) -> ColIndexMapping {
396        let left_len = self.left.schema().len();
397        let right_len = self.right.schema().len();
398
399        match self.join_type {
400            JoinType::Inner
401            | JoinType::LeftOuter
402            | JoinType::RightOuter
403            | JoinType::FullOuter
404            | JoinType::AsofInner
405            | JoinType::AsofLeftOuter => {
406                ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
407            }
408            JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len, right_len),
409            JoinType::RightSemi | JoinType::RightAnti => ColIndexMapping::identity(right_len),
410            JoinType::Unspecified => unreachable!(),
411        }
412    }
413
414    /// TODO: This function may can be merged with `i2l_col_mapping` in future.
415    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
416        let left_len = self.left.schema().len();
417        let right_len = self.right.schema().len();
418
419        ColIndexMapping::identity_or_none(left_len + right_len, left_len)
420    }
421
422    /// TODO: This function may can be merged with `i2r_col_mapping` in future.
423    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
424        let left_len = self.left.schema().len();
425        let right_len = self.right.schema().len();
426
427        ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
428    }
429
430    /// Get the Mapping of columnIndex from left column index to internal column index.
431    pub fn l2i_col_mapping(&self) -> ColIndexMapping {
432        self.i2l_col_mapping()
433            .inverse()
434            .expect("must be invertible")
435    }
436
437    /// Get the Mapping of columnIndex from right column index to internal column index.
438    pub fn r2i_col_mapping(&self) -> ColIndexMapping {
439        self.i2r_col_mapping()
440            .inverse()
441            .expect("must be invertible")
442    }
443
444    /// Get the Mapping of columnIndex from internal column index to output column index
445    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
446        ColIndexMapping::with_remaining_columns(&self.output_indices, self.internal_column_num())
447    }
448
449    /// Get the Mapping of columnIndex from output column index to internal column index
450    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
451        // If output_indices = [0, 0, 1], we should use it as `o2i_col_mapping` directly.
452        // If we use `self.i2o_col_mapping().inverse()`, we will lose the first 0.
453        ColIndexMapping::new(
454            self.output_indices.iter().map(|x| Some(*x)).collect(),
455            self.internal_column_num(),
456        )
457    }
458
459    pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> {
460        match self.join_type {
461            JoinType::Inner | JoinType::AsofInner => {
462                // Theoretically adding either side is ok, but the distribution key of the inner
463                // join derived based on the left side by default, so we choose the left side here
464                // to ensure the pk comprises the distribution key.
465                EitherOrBoth::Left(())
466            }
467            JoinType::LeftOuter
468            | JoinType::LeftSemi
469            | JoinType::LeftAnti
470            | JoinType::AsofLeftOuter => EitherOrBoth::Left(()),
471            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
472                EitherOrBoth::Right(())
473            }
474            JoinType::FullOuter => EitherOrBoth::Both((), ()),
475            JoinType::Unspecified => unreachable!(),
476        }
477    }
478
479    pub fn concat_schema(&self) -> Schema {
480        Schema::new(
481            [
482                self.left.schema().fields.clone(),
483                self.right.schema().fields.clone(),
484            ]
485            .concat(),
486        )
487    }
488}
489
490/// Try to split and pushdown `predicate` into a into a join condition and into the inputs of the
491/// join. Returns the pushed predicates. The pushed part will be removed from the original
492/// predicate.
493///
494/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
495pub fn push_down_into_join(
496    predicate: &mut Condition,
497    left_col_num: usize,
498    right_col_num: usize,
499    ty: JoinType,
500    push_temporal_predicate: bool,
501) -> (Condition, Condition, Condition) {
502    let (left, right) = push_down_to_inputs(
503        predicate,
504        left_col_num,
505        right_col_num,
506        can_push_left_from_filter(ty),
507        can_push_right_from_filter(ty),
508        push_temporal_predicate,
509    );
510
511    let on = if can_push_on_from_filter(ty) {
512        let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
513
514        if push_temporal_predicate {
515            Condition { conjunctions }
516        } else {
517            // Do not push now on to the on, it will be pulled up into a filter instead.
518            let on = Condition {
519                conjunctions: conjunctions
520                    .extract_if(.., |expr| expr.count_nows() == 0)
521                    .collect(),
522            };
523            predicate.conjunctions = conjunctions;
524            on
525        }
526    } else {
527        Condition::true_cond()
528    };
529    (left, right, on)
530}
531
532/// Try to pushes parts of the join condition to its inputs. Returns the pushed predicates. The
533/// pushed part will be removed from the original join predicate.
534///
535/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
536pub fn push_down_join_condition(
537    on_condition: &mut Condition,
538    left_col_num: usize,
539    right_col_num: usize,
540    ty: JoinType,
541    push_temporal_predicate: bool,
542) -> (Condition, Condition) {
543    push_down_to_inputs(
544        on_condition,
545        left_col_num,
546        right_col_num,
547        can_push_left_from_on(ty),
548        can_push_right_from_on(ty),
549        push_temporal_predicate,
550    )
551}
552
553/// Try to split and pushdown `predicate` into a join's left/right child.
554/// Returns the pushed predicates. The pushed part will be removed from the original predicate.
555///
556/// `InputRef`s in the right `Condition` are shifted by `-left_col_num`.
557fn push_down_to_inputs(
558    predicate: &mut Condition,
559    left_col_num: usize,
560    right_col_num: usize,
561    push_left: bool,
562    push_right: bool,
563    push_temporal_predicate: bool,
564) -> (Condition, Condition) {
565    let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
566    let (mut left, right, mut others) = if push_temporal_predicate {
567        Condition { conjunctions }.split(left_col_num, right_col_num)
568    } else {
569        let temporal_filter_cons = conjunctions
570            .extract_if(.., |e| e.count_nows() != 0)
571            .collect_vec();
572        let (left, right, mut others) =
573            Condition { conjunctions }.split(left_col_num, right_col_num);
574
575        others.conjunctions.extend(temporal_filter_cons);
576        (left, right, others)
577    };
578
579    if !push_left {
580        others.conjunctions.extend(left);
581        left = Condition::true_cond();
582    };
583
584    let right = if push_right {
585        let mut mapping = ColIndexMapping::with_shift_offset(
586            left_col_num + right_col_num,
587            -(left_col_num as isize),
588        );
589        right.rewrite_expr(&mut mapping)
590    } else {
591        others.conjunctions.extend(right);
592        Condition::true_cond()
593    };
594
595    predicate.conjunctions = others.conjunctions;
596
597    (left, right)
598}
599
600pub fn can_push_left_from_filter(ty: JoinType) -> bool {
601    matches!(
602        ty,
603        JoinType::Inner
604            | JoinType::LeftOuter
605            | JoinType::LeftSemi
606            | JoinType::LeftAnti
607            | JoinType::AsofInner
608            | JoinType::AsofLeftOuter
609    )
610}
611
612pub fn can_push_right_from_filter(ty: JoinType) -> bool {
613    matches!(
614        ty,
615        JoinType::Inner
616            | JoinType::RightOuter
617            | JoinType::RightSemi
618            | JoinType::RightAnti
619            | JoinType::AsofInner
620    )
621}
622
623pub fn can_push_on_from_filter(ty: JoinType) -> bool {
624    matches!(
625        ty,
626        JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi
627    )
628}
629
630pub fn can_push_left_from_on(ty: JoinType) -> bool {
631    matches!(
632        ty,
633        JoinType::Inner
634            | JoinType::RightOuter
635            | JoinType::LeftSemi
636            | JoinType::AsofInner
637            | JoinType::AsofLeftOuter
638    )
639}
640
641pub fn can_push_right_from_on(ty: JoinType) -> bool {
642    matches!(
643        ty,
644        JoinType::Inner
645            | JoinType::LeftOuter
646            | JoinType::RightSemi
647            | JoinType::AsofInner
648            | JoinType::AsofLeftOuter
649    )
650}