risingwave_frontend/optimizer/rule/
apply_join_transpose_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18use risingwave_common::types::DataType::Boolean;
19use risingwave_pb::plan_common::JoinType;
20
21use super::prelude::{PlanRef, *};
22use crate::expr::{
23    CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall,
24    InputRef,
25};
26use crate::optimizer::plan_node::generic::GenericPlanRef;
27use crate::optimizer::plan_node::{
28    LogicalApply, LogicalFilter, LogicalJoin, PlanTreeNode, PlanTreeNodeBinary,
29};
30use crate::optimizer::plan_visitor::{ExprCorrelatedIdFinder, PlanCorrelatedIdFinder};
31use crate::optimizer::rule::apply_offset_rewriter::ApplyCorrelatedIndicesConverter;
32use crate::utils::{ColIndexMapping, Condition};
33
/// Optimizer rule that swaps a `LogicalApply` with the `LogicalJoin` sitting on its right
/// input, so that the apply ends up directly above whichever join input(s) need it.
///
/// Shape matched by the rule (`Domain` is the apply's left input, `T1` / `T2` are the
/// join's inputs, `p` is the join predicate):
///
/// ```text
///     LogicalApply
///    /            \
///  Domain      LogicalJoin
///                /      \
///               T1     T2
/// ```
///
/// Depending on the join type and on where correlated references are found, one of three
/// rewrites is chosen.
///
/// # `push_apply_both_side`
///
/// `D Apply (T1 join< p > T2)  ->  (D Apply T1) join< p and natural join D > (D Apply T2)`
///
/// ```text
///           LogicalJoin
///         /            \
///  LogicalApply     LogicalApply
///   /      \           /      \
/// Domain   T1        Domain   T2
/// ```
///
/// # `push_apply_left_side`
///
/// `D Apply (T1 join< p > T2)  ->  (D Apply T1) join< p > T2`
///
/// ```text
///        LogicalJoin
///      /            \
///  LogicalApply    T2
///   /      \
/// Domain   T1
/// ```
///
/// # `push_apply_right_side`
///
/// `D Apply (T1 join< p > T2)  ->  T1 join< p > (D Apply T2)`
///
/// ```text
///        LogicalJoin
///      /            \
///    T1         LogicalApply
///                /      \
///              Domain   T2
/// ```
pub struct ApplyJoinTransposeRule {}
88impl Rule<Logical> for ApplyJoinTransposeRule {
89    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
90        let apply: &LogicalApply = plan.as_logical_apply()?;
91        let (
92            apply_left,
93            apply_right,
94            apply_on,
95            apply_join_type,
96            correlated_id,
97            correlated_indices,
98            max_one_row,
99        ) = apply.clone().decompose();
100
101        if max_one_row {
102            return None;
103        }
104
105        assert_eq!(apply_join_type, JoinType::Inner);
106        let join: &LogicalJoin = apply_right.as_logical_join()?;
107
108        let mut finder = ExprCorrelatedIdFinder::default();
109        join.on().visit_expr(&mut finder);
110        let join_cond_has_correlated_id = finder.contains(&correlated_id);
111        let join_left_has_correlated_id =
112            PlanCorrelatedIdFinder::find_correlated_id(join.left(), &correlated_id);
113        let join_right_has_correlated_id =
114            PlanCorrelatedIdFinder::find_correlated_id(join.right(), &correlated_id);
115
116        // Shortcut
117        // Check whether correlated_input_ref with same correlated_id exists below apply.
118        // If no, bail out and leave for `ApplyEliminateRule` to deal with.
119        if !join_cond_has_correlated_id
120            && !join_left_has_correlated_id
121            && !join_right_has_correlated_id
122        {
123            return None;
124        }
125
126        // ApplyJoinTransposeRule requires the join containing no output indices, so make sure ProjectJoinSeparateRule is always applied before this rule.
127        // As this rule will be applied until we reach the fixed point, if the join has output indices, apply ProjectJoinSeparateRule first and return is safety.
128        if !join.output_indices_are_trivial() {
129            let new_apply_right = crate::optimizer::rule::ProjectJoinSeparateRule::create()
130                .apply(join.clone().into())
131                .unwrap();
132            return Some(apply.clone_with_inputs(&[apply_left, new_apply_right]));
133        }
134
135        let (push_left, push_right) = match join.join_type() {
136            // `LeftSemi`, `LeftAnti`, `LeftOuter` can only push to left side if it's right side has
137            // no correlated id. Otherwise push to both sides.
138            JoinType::LeftSemi
139            | JoinType::LeftAnti
140            | JoinType::LeftOuter
141            | JoinType::AsofLeftOuter => {
142                if !join_right_has_correlated_id {
143                    (true, false)
144                } else {
145                    (true, true)
146                }
147            }
148            // `RightSemi`, `RightAnti`, `RightOuter` can only push to right side if it's left side
149            // has no correlated id. Otherwise push to both sides.
150            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
151                if !join_left_has_correlated_id {
152                    (false, true)
153                } else {
154                    (true, true)
155                }
156            }
157            // `Inner` can push to one side if the other side is not dependent on it.
158            JoinType::Inner | JoinType::AsofInner => {
159                if join_cond_has_correlated_id
160                    && !join_right_has_correlated_id
161                    && !join_left_has_correlated_id
162                {
163                    (true, false)
164                } else {
165                    (join_left_has_correlated_id, join_right_has_correlated_id)
166                }
167            }
168            // `FullOuter` should always push to both sides.
169            JoinType::FullOuter => (true, true),
170            JoinType::Unspecified => unreachable!(),
171        };
172
173        let out = if push_left && push_right {
174            self.push_apply_both_side(
175                apply_left,
176                join,
177                apply_on,
178                apply_join_type,
179                correlated_id,
180                correlated_indices,
181            )
182        } else if push_left {
183            self.push_apply_left_side(
184                apply_left,
185                join,
186                apply_on,
187                apply_join_type,
188                correlated_id,
189                correlated_indices,
190            )
191        } else if push_right {
192            self.push_apply_right_side(
193                apply_left,
194                join,
195                apply_on,
196                apply_join_type,
197                correlated_id,
198                correlated_indices,
199            )
200        } else {
201            unreachable!();
202        };
203        assert_eq!(out.schema(), plan.schema());
204        Some(out)
205    }
206}
207
208impl ApplyJoinTransposeRule {
209    fn push_apply_left_side(
210        &self,
211        apply_left: PlanRef,
212        join: &LogicalJoin,
213        apply_on: Condition,
214        apply_join_type: JoinType,
215        correlated_id: CorrelatedId,
216        correlated_indices: Vec<usize>,
217    ) -> PlanRef {
218        let apply_left_len = apply_left.schema().len();
219        let join_left_len = join.left().schema().len();
220        let mut rewriter = Rewriter {
221            join_left_len,
222            join_left_offset: apply_left_len as isize,
223            join_right_offset: apply_left_len as isize,
224            index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
225                &correlated_indices,
226            ),
227            correlated_id,
228        };
229
230        // Rewrite join on condition
231        let new_join_condition = Condition {
232            conjunctions: join
233                .on()
234                .clone()
235                .into_iter()
236                .map(|expr| rewriter.rewrite_expr(expr))
237                .collect_vec(),
238        };
239
240        let mut left_apply_condition: Vec<ExprImpl> = vec![];
241        let mut other_condition: Vec<ExprImpl> = vec![];
242
243        match join.join_type() {
244            JoinType::LeftSemi | JoinType::LeftAnti => {
245                left_apply_condition.extend(apply_on);
246            }
247            JoinType::Inner
248            | JoinType::LeftOuter
249            | JoinType::RightOuter
250            | JoinType::FullOuter
251            | JoinType::AsofInner
252            | JoinType::AsofLeftOuter => {
253                let apply_len = apply_left_len + join.schema().len();
254                let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
255                d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
256
257                let (left, other): (Vec<_>, Vec<_>) = apply_on
258                    .into_iter()
259                    .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t1_bit_set));
260                left_apply_condition.extend(left);
261                other_condition.extend(other);
262            }
263            JoinType::RightSemi | JoinType::RightAnti | JoinType::Unspecified => unreachable!(),
264        }
265
266        let new_join_left = LogicalApply::create(
267            apply_left,
268            join.left(),
269            apply_join_type,
270            Condition {
271                conjunctions: left_apply_condition,
272            },
273            correlated_id,
274            correlated_indices,
275            false,
276        );
277
278        let new_join = LogicalJoin::new(
279            new_join_left,
280            join.right(),
281            join.join_type(),
282            new_join_condition,
283        );
284
285        // Leave other condition for predicate push down to deal with
286        LogicalFilter::create(
287            new_join.into(),
288            Condition {
289                conjunctions: other_condition,
290            },
291        )
292    }
293
294    fn push_apply_right_side(
295        &self,
296        apply_left: PlanRef,
297        join: &LogicalJoin,
298        apply_on: Condition,
299        apply_join_type: JoinType,
300        correlated_id: CorrelatedId,
301        correlated_indices: Vec<usize>,
302    ) -> PlanRef {
303        let apply_left_len = apply_left.schema().len();
304        let join_left_len = join.left().schema().len();
305        let mut rewriter = Rewriter {
306            join_left_len,
307            join_left_offset: 0,
308            join_right_offset: apply_left_len as isize,
309            index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
310                &correlated_indices,
311            ),
312            correlated_id,
313        };
314
315        // Rewrite join on condition
316        let new_join_condition = Condition {
317            conjunctions: join
318                .on()
319                .clone()
320                .into_iter()
321                .map(|expr| rewriter.rewrite_expr(expr))
322                .collect_vec(),
323        };
324
325        let mut right_apply_condition: Vec<ExprImpl> = vec![];
326        let mut other_condition: Vec<ExprImpl> = vec![];
327
328        match join.join_type() {
329            JoinType::RightSemi | JoinType::RightAnti => {
330                right_apply_condition.extend(apply_on);
331            }
332            JoinType::Inner
333            | JoinType::LeftOuter
334            | JoinType::RightOuter
335            | JoinType::FullOuter
336            | JoinType::AsofInner
337            | JoinType::AsofLeftOuter => {
338                let apply_len = apply_left_len + join.schema().len();
339                let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
340                d_t2_bit_set.set_range(0..apply_left_len, true);
341                d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
342
343                let (right, other): (Vec<_>, Vec<_>) = apply_on
344                    .into_iter()
345                    .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t2_bit_set));
346                right_apply_condition.extend(right);
347                other_condition.extend(other);
348
349                // rewrite right condition
350                let mut right_apply_condition_rewriter = Rewriter {
351                    join_left_len: apply_left_len,
352                    join_left_offset: 0,
353                    join_right_offset: -(join_left_len as isize),
354                    index_mapping: ColIndexMapping::empty(0, 0),
355                    correlated_id,
356                };
357
358                right_apply_condition = right_apply_condition
359                    .into_iter()
360                    .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
361                    .collect_vec();
362            }
363            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Unspecified => unreachable!(),
364        }
365
366        let new_join_right = LogicalApply::create(
367            apply_left,
368            join.right(),
369            apply_join_type,
370            Condition {
371                conjunctions: right_apply_condition,
372            },
373            correlated_id,
374            correlated_indices,
375            false,
376        );
377        let (output_indices, target_size) = {
378            let (apply_left_len, join_right_len) = match apply_join_type {
379                JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
380                JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
381                _ => (apply_left_len, join.right().schema().len()),
382            };
383
384            let left_iter = join_left_len..join_left_len + apply_left_len;
385            let right_iter = (0..join_left_len).chain(
386                join_left_len + apply_left_len..join_left_len + apply_left_len + join_right_len,
387            );
388
389            let output_indices: Vec<_> = match join.join_type() {
390                JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
391                JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
392                _ => left_iter.chain(right_iter).collect(),
393            };
394
395            let target_size = join_left_len + apply_left_len + join_right_len;
396            (output_indices, target_size)
397        };
398        let mut output_indices_mapping = ColIndexMapping::new(
399            output_indices.iter().map(|x| Some(*x)).collect(),
400            target_size,
401        );
402        let new_join = LogicalJoin::new(
403            join.left(),
404            new_join_right,
405            join.join_type(),
406            new_join_condition,
407        )
408        .clone_with_output_indices(output_indices);
409
410        // Leave other condition for predicate push down to deal with
411        LogicalFilter::create(
412            new_join.into(),
413            Condition {
414                conjunctions: other_condition,
415            }
416            .rewrite_expr(&mut output_indices_mapping),
417        )
418    }
419
420    fn push_apply_both_side(
421        &self,
422        apply_left: PlanRef,
423        join: &LogicalJoin,
424        apply_on: Condition,
425        apply_join_type: JoinType,
426        correlated_id: CorrelatedId,
427        correlated_indices: Vec<usize>,
428    ) -> PlanRef {
429        let apply_left_len = apply_left.schema().len();
430        let join_left_len = join.left().schema().len();
431        let mut rewriter = Rewriter {
432            join_left_len,
433            join_left_offset: apply_left_len as isize,
434            join_right_offset: 2 * apply_left_len as isize,
435            index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
436                &correlated_indices,
437            ),
438            correlated_id,
439        };
440
441        // Rewrite join on condition and add natural join condition
442        let natural_conjunctions = apply_left
443            .schema()
444            .fields
445            .iter()
446            .enumerate()
447            .map(|(i, field)| {
448                Self::create_null_safe_equal_expr(
449                    i,
450                    field.data_type.clone(),
451                    i + join_left_len + apply_left_len,
452                    field.data_type.clone(),
453                )
454            })
455            .collect_vec();
456        let new_join_condition = Condition {
457            conjunctions: join
458                .on()
459                .clone()
460                .into_iter()
461                .map(|expr| rewriter.rewrite_expr(expr))
462                .chain(natural_conjunctions)
463                .collect_vec(),
464        };
465
466        let mut left_apply_condition: Vec<ExprImpl> = vec![];
467        let mut right_apply_condition: Vec<ExprImpl> = vec![];
468        let mut other_condition: Vec<ExprImpl> = vec![];
469
470        match join.join_type() {
471            JoinType::LeftSemi | JoinType::LeftAnti => {
472                left_apply_condition.extend(apply_on);
473            }
474            JoinType::RightSemi | JoinType::RightAnti => {
475                right_apply_condition.extend(apply_on);
476            }
477            JoinType::Inner
478            | JoinType::LeftOuter
479            | JoinType::RightOuter
480            | JoinType::FullOuter
481            | JoinType::AsofInner
482            | JoinType::AsofLeftOuter => {
483                let apply_len = apply_left_len + join.schema().len();
484                let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
485                let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
486                d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
487                d_t2_bit_set.set_range(0..apply_left_len, true);
488                d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
489
490                for (key, group) in &apply_on.into_iter().chunk_by(|expr| {
491                    let collect_bit_set = expr.collect_input_refs(apply_len);
492                    if collect_bit_set.is_subset(&d_t1_bit_set) {
493                        0
494                    } else if collect_bit_set.is_subset(&d_t2_bit_set) {
495                        1
496                    } else {
497                        2
498                    }
499                }) {
500                    let vec = group.collect_vec();
501                    match key {
502                        0 => left_apply_condition.extend(vec),
503                        1 => right_apply_condition.extend(vec),
504                        2 => other_condition.extend(vec),
505                        _ => unreachable!(),
506                    }
507                }
508
509                // Rewrite right condition
510                let mut right_apply_condition_rewriter = Rewriter {
511                    join_left_len: apply_left_len,
512                    join_left_offset: 0,
513                    join_right_offset: -(join_left_len as isize),
514                    index_mapping: ColIndexMapping::empty(0, 0),
515                    correlated_id,
516                };
517
518                right_apply_condition = right_apply_condition
519                    .into_iter()
520                    .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
521                    .collect_vec();
522            }
523            JoinType::Unspecified => unreachable!(),
524        }
525
526        let new_join_left = LogicalApply::create(
527            apply_left.clone(),
528            join.left(),
529            apply_join_type,
530            Condition {
531                conjunctions: left_apply_condition,
532            },
533            correlated_id,
534            correlated_indices.clone(),
535            false,
536        );
537        let new_join_right = LogicalApply::create(
538            apply_left,
539            join.right(),
540            apply_join_type,
541            Condition {
542                conjunctions: right_apply_condition,
543            },
544            correlated_id,
545            correlated_indices,
546            false,
547        );
548
549        let (output_indices, target_size) = {
550            let (apply_left_len, join_right_len) = match apply_join_type {
551                JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
552                JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
553                _ => (apply_left_len, join.right().schema().len()),
554            };
555
556            let left_iter = 0..join_left_len + apply_left_len;
557            let right_iter = join_left_len + apply_left_len * 2
558                ..join_left_len + apply_left_len * 2 + join_right_len;
559
560            let output_indices: Vec<_> = match join.join_type() {
561                JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
562                JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
563                _ => left_iter.chain(right_iter).collect(),
564            };
565
566            let target_size = join_left_len + apply_left_len * 2 + join_right_len;
567            (output_indices, target_size)
568        };
569        let new_join = LogicalJoin::new(
570            new_join_left,
571            new_join_right,
572            join.join_type(),
573            new_join_condition,
574        )
575        .clone_with_output_indices(output_indices.clone());
576
577        match join.join_type() {
578            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => {
579                new_join.into()
580            }
581            JoinType::Inner
582            | JoinType::LeftOuter
583            | JoinType::RightOuter
584            | JoinType::FullOuter
585            | JoinType::AsofInner
586            | JoinType::AsofLeftOuter => {
587                let mut output_indices_mapping = ColIndexMapping::new(
588                    output_indices.iter().map(|x| Some(*x)).collect(),
589                    target_size,
590                );
591                // Leave other condition for predicate push down to deal with
592                LogicalFilter::create(
593                    new_join.into(),
594                    Condition {
595                        conjunctions: other_condition,
596                    }
597                    .rewrite_expr(&mut output_indices_mapping),
598                )
599            }
600            JoinType::Unspecified => unreachable!(),
601        }
602    }
603
604    fn create_null_safe_equal_expr(
605        left: usize,
606        left_data_type: DataType,
607        right: usize,
608        right_data_type: DataType,
609    ) -> ExprImpl {
610        // use null-safe equal
611        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
612            ExprType::IsNotDistinctFrom,
613            vec![
614                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
615                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
616            ],
617            Boolean,
618        )))
619    }
620}
621
622impl ApplyJoinTransposeRule {
623    pub fn create() -> BoxedRule {
624        Box::new(ApplyJoinTransposeRule {})
625    }
626}
627
628/// Convert `CorrelatedInputRef` to `InputRef` and shift `InputRef` with offset.
629struct Rewriter {
630    join_left_len: usize,
631    join_left_offset: isize,
632    join_right_offset: isize,
633    index_mapping: ColIndexMapping,
634    correlated_id: CorrelatedId,
635}
636impl ExprRewriter for Rewriter {
637    fn rewrite_correlated_input_ref(
638        &mut self,
639        correlated_input_ref: CorrelatedInputRef,
640    ) -> ExprImpl {
641        if correlated_input_ref.correlated_id() == self.correlated_id {
642            InputRef::new(
643                self.index_mapping.map(correlated_input_ref.index()),
644                correlated_input_ref.return_type(),
645            )
646            .into()
647        } else {
648            correlated_input_ref.into()
649        }
650    }
651
652    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
653        if input_ref.index < self.join_left_len {
654            InputRef::new(
655                (input_ref.index() as isize + self.join_left_offset) as usize,
656                input_ref.return_type(),
657            )
658            .into()
659        } else {
660            InputRef::new(
661                (input_ref.index() as isize + self.join_right_offset) as usize,
662                input_ref.return_type(),
663            )
664            .into()
665        }
666    }
667}