risingwave_frontend/optimizer/rule/apply_join_transpose_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18use risingwave_common::types::DataType::Boolean;
19use risingwave_pb::plan_common::JoinType;
20
21use super::{BoxedRule, Rule};
22use crate::expr::{
23    CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall,
24    InputRef,
25};
26use crate::optimizer::PlanRef;
27use crate::optimizer::plan_node::generic::GenericPlanRef;
28use crate::optimizer::plan_node::{
29    LogicalApply, LogicalFilter, LogicalJoin, PlanTreeNode, PlanTreeNodeBinary,
30};
31use crate::optimizer::plan_visitor::{ExprCorrelatedIdFinder, PlanCorrelatedIdFinder};
32use crate::optimizer::rule::apply_offset_rewriter::ApplyCorrelatedIndicesConverter;
33use crate::utils::{ColIndexMapping, Condition};
34
35/// Transpose `LogicalApply` and `LogicalJoin`.
36///
37/// Before:
38///
39/// ```text
40///     LogicalApply
41///    /            \
42///  Domain      LogicalJoin
43///                /      \
44///               T1     T2
45/// ```
46///
47/// `push_apply_both_side`:
48///
49/// D Apply (T1 join< p > T2)  ->  (D Apply T1) join< p and natural join D > (D Apply T2)
50///
51/// After:
52///
53/// ```text
54///           LogicalJoin
55///         /            \
56///  LogicalApply     LogicalApply
57///   /      \           /      \
58/// Domain   T1        Domain   T2
59/// ```
60///
61/// `push_apply_left_side`:
62///
63/// D Apply (T1 join< p > T2)  ->  (D Apply T1) join< p > T2
64///
65/// After:
66///
67/// ```text
68///        LogicalJoin
69///      /            \
70///  LogicalApply    T2
71///   /      \
72/// Domain   T1
73/// ```
74///
75/// `push_apply_right_side`:
76///
77/// D Apply (T1 join< p > T2)  ->  T1 join< p > (D Apply T2)
78///
79/// After:
80///
81/// ```text
82///        LogicalJoin
83///      /            \
84///    T1         LogicalApply
85///                /      \
86///              Domain   T2
87/// ```
88pub struct ApplyJoinTransposeRule {}
89impl Rule for ApplyJoinTransposeRule {
90    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
91        let apply: &LogicalApply = plan.as_logical_apply()?;
92        let (
93            apply_left,
94            apply_right,
95            apply_on,
96            apply_join_type,
97            correlated_id,
98            correlated_indices,
99            max_one_row,
100        ) = apply.clone().decompose();
101
102        if max_one_row {
103            return None;
104        }
105
106        assert_eq!(apply_join_type, JoinType::Inner);
107        let join: &LogicalJoin = apply_right.as_logical_join()?;
108
109        let mut finder = ExprCorrelatedIdFinder::default();
110        join.on().visit_expr(&mut finder);
111        let join_cond_has_correlated_id = finder.contains(&correlated_id);
112        let join_left_has_correlated_id =
113            PlanCorrelatedIdFinder::find_correlated_id(join.left(), &correlated_id);
114        let join_right_has_correlated_id =
115            PlanCorrelatedIdFinder::find_correlated_id(join.right(), &correlated_id);
116
117        // Shortcut
118        // Check whether correlated_input_ref with same correlated_id exists below apply.
119        // If no, bail out and leave for `ApplyEliminateRule` to deal with.
120        if !join_cond_has_correlated_id
121            && !join_left_has_correlated_id
122            && !join_right_has_correlated_id
123        {
124            return None;
125        }
126
127        // ApplyJoinTransposeRule requires the join containing no output indices, so make sure ProjectJoinSeparateRule is always applied before this rule.
128        // As this rule will be applied until we reach the fixed point, if the join has output indices, apply ProjectJoinSeparateRule first and return is safety.
129        if !join.output_indices_are_trivial() {
130            let new_apply_right = crate::optimizer::rule::ProjectJoinSeparateRule::create()
131                .apply(join.clone().into())
132                .unwrap();
133            return Some(apply.clone_with_inputs(&[apply_left, new_apply_right]));
134        }
135
136        let (push_left, push_right) = match join.join_type() {
137            // `LeftSemi`, `LeftAnti`, `LeftOuter` can only push to left side if it's right side has
138            // no correlated id. Otherwise push to both sides.
139            JoinType::LeftSemi
140            | JoinType::LeftAnti
141            | JoinType::LeftOuter
142            | JoinType::AsofLeftOuter => {
143                if !join_right_has_correlated_id {
144                    (true, false)
145                } else {
146                    (true, true)
147                }
148            }
149            // `RightSemi`, `RightAnti`, `RightOuter` can only push to right side if it's left side
150            // has no correlated id. Otherwise push to both sides.
151            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => {
152                if !join_left_has_correlated_id {
153                    (false, true)
154                } else {
155                    (true, true)
156                }
157            }
158            // `Inner` can push to one side if the other side is not dependent on it.
159            JoinType::Inner | JoinType::AsofInner => {
160                if join_cond_has_correlated_id
161                    && !join_right_has_correlated_id
162                    && !join_left_has_correlated_id
163                {
164                    (true, false)
165                } else {
166                    (join_left_has_correlated_id, join_right_has_correlated_id)
167                }
168            }
169            // `FullOuter` should always push to both sides.
170            JoinType::FullOuter => (true, true),
171            JoinType::Unspecified => unreachable!(),
172        };
173
174        let out = if push_left && push_right {
175            self.push_apply_both_side(
176                apply_left,
177                join,
178                apply_on,
179                apply_join_type,
180                correlated_id,
181                correlated_indices,
182            )
183        } else if push_left {
184            self.push_apply_left_side(
185                apply_left,
186                join,
187                apply_on,
188                apply_join_type,
189                correlated_id,
190                correlated_indices,
191            )
192        } else if push_right {
193            self.push_apply_right_side(
194                apply_left,
195                join,
196                apply_on,
197                apply_join_type,
198                correlated_id,
199                correlated_indices,
200            )
201        } else {
202            unreachable!();
203        };
204        assert_eq!(out.schema(), plan.schema());
205        Some(out)
206    }
207}
208
209impl ApplyJoinTransposeRule {
210    fn push_apply_left_side(
211        &self,
212        apply_left: PlanRef,
213        join: &LogicalJoin,
214        apply_on: Condition,
215        apply_join_type: JoinType,
216        correlated_id: CorrelatedId,
217        correlated_indices: Vec<usize>,
218    ) -> PlanRef {
219        let apply_left_len = apply_left.schema().len();
220        let join_left_len = join.left().schema().len();
221        let mut rewriter = Rewriter {
222            join_left_len,
223            join_left_offset: apply_left_len as isize,
224            join_right_offset: apply_left_len as isize,
225            index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
226                &correlated_indices,
227            ),
228            correlated_id,
229        };
230
231        // Rewrite join on condition
232        let new_join_condition = Condition {
233            conjunctions: join
234                .on()
235                .clone()
236                .into_iter()
237                .map(|expr| rewriter.rewrite_expr(expr))
238                .collect_vec(),
239        };
240
241        let mut left_apply_condition: Vec<ExprImpl> = vec![];
242        let mut other_condition: Vec<ExprImpl> = vec![];
243
244        match join.join_type() {
245            JoinType::LeftSemi | JoinType::LeftAnti => {
246                left_apply_condition.extend(apply_on);
247            }
248            JoinType::Inner
249            | JoinType::LeftOuter
250            | JoinType::RightOuter
251            | JoinType::FullOuter
252            | JoinType::AsofInner
253            | JoinType::AsofLeftOuter => {
254                let apply_len = apply_left_len + join.schema().len();
255                let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
256                d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
257
258                let (left, other): (Vec<_>, Vec<_>) = apply_on
259                    .into_iter()
260                    .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t1_bit_set));
261                left_apply_condition.extend(left);
262                other_condition.extend(other);
263            }
264            JoinType::RightSemi | JoinType::RightAnti | JoinType::Unspecified => unreachable!(),
265        }
266
267        let new_join_left = LogicalApply::create(
268            apply_left,
269            join.left(),
270            apply_join_type,
271            Condition {
272                conjunctions: left_apply_condition,
273            },
274            correlated_id,
275            correlated_indices,
276            false,
277        );
278
279        let new_join = LogicalJoin::new(
280            new_join_left,
281            join.right(),
282            join.join_type(),
283            new_join_condition,
284        );
285
286        // Leave other condition for predicate push down to deal with
287        LogicalFilter::create(
288            new_join.into(),
289            Condition {
290                conjunctions: other_condition,
291            },
292        )
293    }
294
295    fn push_apply_right_side(
296        &self,
297        apply_left: PlanRef,
298        join: &LogicalJoin,
299        apply_on: Condition,
300        apply_join_type: JoinType,
301        correlated_id: CorrelatedId,
302        correlated_indices: Vec<usize>,
303    ) -> PlanRef {
304        let apply_left_len = apply_left.schema().len();
305        let join_left_len = join.left().schema().len();
306        let mut rewriter = Rewriter {
307            join_left_len,
308            join_left_offset: 0,
309            join_right_offset: apply_left_len as isize,
310            index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
311                &correlated_indices,
312            ),
313            correlated_id,
314        };
315
316        // Rewrite join on condition
317        let new_join_condition = Condition {
318            conjunctions: join
319                .on()
320                .clone()
321                .into_iter()
322                .map(|expr| rewriter.rewrite_expr(expr))
323                .collect_vec(),
324        };
325
326        let mut right_apply_condition: Vec<ExprImpl> = vec![];
327        let mut other_condition: Vec<ExprImpl> = vec![];
328
329        match join.join_type() {
330            JoinType::RightSemi | JoinType::RightAnti => {
331                right_apply_condition.extend(apply_on);
332            }
333            JoinType::Inner
334            | JoinType::LeftOuter
335            | JoinType::RightOuter
336            | JoinType::FullOuter
337            | JoinType::AsofInner
338            | JoinType::AsofLeftOuter => {
339                let apply_len = apply_left_len + join.schema().len();
340                let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
341                d_t2_bit_set.set_range(0..apply_left_len, true);
342                d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
343
344                let (right, other): (Vec<_>, Vec<_>) = apply_on
345                    .into_iter()
346                    .partition(|expr| expr.collect_input_refs(apply_len).is_subset(&d_t2_bit_set));
347                right_apply_condition.extend(right);
348                other_condition.extend(other);
349
350                // rewrite right condition
351                let mut right_apply_condition_rewriter = Rewriter {
352                    join_left_len: apply_left_len,
353                    join_left_offset: 0,
354                    join_right_offset: -(join_left_len as isize),
355                    index_mapping: ColIndexMapping::empty(0, 0),
356                    correlated_id,
357                };
358
359                right_apply_condition = right_apply_condition
360                    .into_iter()
361                    .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
362                    .collect_vec();
363            }
364            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Unspecified => unreachable!(),
365        }
366
367        let new_join_right = LogicalApply::create(
368            apply_left,
369            join.right(),
370            apply_join_type,
371            Condition {
372                conjunctions: right_apply_condition,
373            },
374            correlated_id,
375            correlated_indices,
376            false,
377        );
378        let (output_indices, target_size) = {
379            let (apply_left_len, join_right_len) = match apply_join_type {
380                JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
381                JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
382                _ => (apply_left_len, join.right().schema().len()),
383            };
384
385            let left_iter = join_left_len..join_left_len + apply_left_len;
386            let right_iter = (0..join_left_len).chain(
387                join_left_len + apply_left_len..join_left_len + apply_left_len + join_right_len,
388            );
389
390            let output_indices: Vec<_> = match join.join_type() {
391                JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
392                JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
393                _ => left_iter.chain(right_iter).collect(),
394            };
395
396            let target_size = join_left_len + apply_left_len + join_right_len;
397            (output_indices, target_size)
398        };
399        let mut output_indices_mapping = ColIndexMapping::new(
400            output_indices.iter().map(|x| Some(*x)).collect(),
401            target_size,
402        );
403        let new_join = LogicalJoin::new(
404            join.left(),
405            new_join_right,
406            join.join_type(),
407            new_join_condition,
408        )
409        .clone_with_output_indices(output_indices);
410
411        // Leave other condition for predicate push down to deal with
412        LogicalFilter::create(
413            new_join.into(),
414            Condition {
415                conjunctions: other_condition,
416            }
417            .rewrite_expr(&mut output_indices_mapping),
418        )
419    }
420
421    fn push_apply_both_side(
422        &self,
423        apply_left: PlanRef,
424        join: &LogicalJoin,
425        apply_on: Condition,
426        apply_join_type: JoinType,
427        correlated_id: CorrelatedId,
428        correlated_indices: Vec<usize>,
429    ) -> PlanRef {
430        let apply_left_len = apply_left.schema().len();
431        let join_left_len = join.left().schema().len();
432        let mut rewriter = Rewriter {
433            join_left_len,
434            join_left_offset: apply_left_len as isize,
435            join_right_offset: 2 * apply_left_len as isize,
436            index_mapping: ApplyCorrelatedIndicesConverter::convert_to_index_mapping(
437                &correlated_indices,
438            ),
439            correlated_id,
440        };
441
442        // Rewrite join on condition and add natural join condition
443        let natural_conjunctions = apply_left
444            .schema()
445            .fields
446            .iter()
447            .enumerate()
448            .map(|(i, field)| {
449                Self::create_null_safe_equal_expr(
450                    i,
451                    field.data_type.clone(),
452                    i + join_left_len + apply_left_len,
453                    field.data_type.clone(),
454                )
455            })
456            .collect_vec();
457        let new_join_condition = Condition {
458            conjunctions: join
459                .on()
460                .clone()
461                .into_iter()
462                .map(|expr| rewriter.rewrite_expr(expr))
463                .chain(natural_conjunctions)
464                .collect_vec(),
465        };
466
467        let mut left_apply_condition: Vec<ExprImpl> = vec![];
468        let mut right_apply_condition: Vec<ExprImpl> = vec![];
469        let mut other_condition: Vec<ExprImpl> = vec![];
470
471        match join.join_type() {
472            JoinType::LeftSemi | JoinType::LeftAnti => {
473                left_apply_condition.extend(apply_on);
474            }
475            JoinType::RightSemi | JoinType::RightAnti => {
476                right_apply_condition.extend(apply_on);
477            }
478            JoinType::Inner
479            | JoinType::LeftOuter
480            | JoinType::RightOuter
481            | JoinType::FullOuter
482            | JoinType::AsofInner
483            | JoinType::AsofLeftOuter => {
484                let apply_len = apply_left_len + join.schema().len();
485                let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len);
486                let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len);
487                d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true);
488                d_t2_bit_set.set_range(0..apply_left_len, true);
489                d_t2_bit_set.set_range(apply_left_len + join_left_len..apply_len, true);
490
491                for (key, group) in &apply_on.into_iter().chunk_by(|expr| {
492                    let collect_bit_set = expr.collect_input_refs(apply_len);
493                    if collect_bit_set.is_subset(&d_t1_bit_set) {
494                        0
495                    } else if collect_bit_set.is_subset(&d_t2_bit_set) {
496                        1
497                    } else {
498                        2
499                    }
500                }) {
501                    let vec = group.collect_vec();
502                    match key {
503                        0 => left_apply_condition.extend(vec),
504                        1 => right_apply_condition.extend(vec),
505                        2 => other_condition.extend(vec),
506                        _ => unreachable!(),
507                    }
508                }
509
510                // Rewrite right condition
511                let mut right_apply_condition_rewriter = Rewriter {
512                    join_left_len: apply_left_len,
513                    join_left_offset: 0,
514                    join_right_offset: -(join_left_len as isize),
515                    index_mapping: ColIndexMapping::empty(0, 0),
516                    correlated_id,
517                };
518
519                right_apply_condition = right_apply_condition
520                    .into_iter()
521                    .map(|expr| right_apply_condition_rewriter.rewrite_expr(expr))
522                    .collect_vec();
523            }
524            JoinType::Unspecified => unreachable!(),
525        }
526
527        let new_join_left = LogicalApply::create(
528            apply_left.clone(),
529            join.left(),
530            apply_join_type,
531            Condition {
532                conjunctions: left_apply_condition,
533            },
534            correlated_id,
535            correlated_indices.clone(),
536            false,
537        );
538        let new_join_right = LogicalApply::create(
539            apply_left,
540            join.right(),
541            apply_join_type,
542            Condition {
543                conjunctions: right_apply_condition,
544            },
545            correlated_id,
546            correlated_indices,
547            false,
548        );
549
550        let (output_indices, target_size) = {
551            let (apply_left_len, join_right_len) = match apply_join_type {
552                JoinType::LeftSemi | JoinType::LeftAnti => (apply_left_len, 0),
553                JoinType::RightSemi | JoinType::RightAnti => (0, join.right().schema().len()),
554                _ => (apply_left_len, join.right().schema().len()),
555            };
556
557            let left_iter = 0..join_left_len + apply_left_len;
558            let right_iter = join_left_len + apply_left_len * 2
559                ..join_left_len + apply_left_len * 2 + join_right_len;
560
561            let output_indices: Vec<_> = match join.join_type() {
562                JoinType::LeftSemi | JoinType::LeftAnti => left_iter.collect(),
563                JoinType::RightSemi | JoinType::RightAnti => right_iter.collect(),
564                _ => left_iter.chain(right_iter).collect(),
565            };
566
567            let target_size = join_left_len + apply_left_len * 2 + join_right_len;
568            (output_indices, target_size)
569        };
570        let new_join = LogicalJoin::new(
571            new_join_left,
572            new_join_right,
573            join.join_type(),
574            new_join_condition,
575        )
576        .clone_with_output_indices(output_indices.clone());
577
578        match join.join_type() {
579            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => {
580                new_join.into()
581            }
582            JoinType::Inner
583            | JoinType::LeftOuter
584            | JoinType::RightOuter
585            | JoinType::FullOuter
586            | JoinType::AsofInner
587            | JoinType::AsofLeftOuter => {
588                let mut output_indices_mapping = ColIndexMapping::new(
589                    output_indices.iter().map(|x| Some(*x)).collect(),
590                    target_size,
591                );
592                // Leave other condition for predicate push down to deal with
593                LogicalFilter::create(
594                    new_join.into(),
595                    Condition {
596                        conjunctions: other_condition,
597                    }
598                    .rewrite_expr(&mut output_indices_mapping),
599                )
600            }
601            JoinType::Unspecified => unreachable!(),
602        }
603    }
604
605    fn create_null_safe_equal_expr(
606        left: usize,
607        left_data_type: DataType,
608        right: usize,
609        right_data_type: DataType,
610    ) -> ExprImpl {
611        // use null-safe equal
612        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
613            ExprType::IsNotDistinctFrom,
614            vec![
615                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
616                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
617            ],
618            Boolean,
619        )))
620    }
621}
622
623impl ApplyJoinTransposeRule {
624    pub fn create() -> BoxedRule {
625        Box::new(ApplyJoinTransposeRule {})
626    }
627}
628
629/// Convert `CorrelatedInputRef` to `InputRef` and shift `InputRef` with offset.
630struct Rewriter {
631    join_left_len: usize,
632    join_left_offset: isize,
633    join_right_offset: isize,
634    index_mapping: ColIndexMapping,
635    correlated_id: CorrelatedId,
636}
637impl ExprRewriter for Rewriter {
638    fn rewrite_correlated_input_ref(
639        &mut self,
640        correlated_input_ref: CorrelatedInputRef,
641    ) -> ExprImpl {
642        if correlated_input_ref.correlated_id() == self.correlated_id {
643            InputRef::new(
644                self.index_mapping.map(correlated_input_ref.index()),
645                correlated_input_ref.return_type(),
646            )
647            .into()
648        } else {
649            correlated_input_ref.into()
650        }
651    }
652
653    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
654        if input_ref.index < self.join_left_len {
655            InputRef::new(
656                (input_ref.index() as isize + self.join_left_offset) as usize,
657                input_ref.return_type(),
658            )
659            .into()
660        } else {
661            InputRef::new(
662                (input_ref.index() as isize + self.join_right_offset) as usize,
663                input_ref.return_type(),
664            )
665            .into()
666        }
667    }
668}