1use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31 ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeBinary, PredicatePushdown,
32 StreamHashJoin, StreamProject, ToBatch, ToStream, generic,
33};
34use crate::error::{ErrorCode, Result, RwError};
35use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::DynamicFilter;
38use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
39use crate::optimizer::plan_node::utils::IndicesDisplay;
40use crate::optimizer::plan_node::{
41 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
42 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
43 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
44};
45use crate::optimizer::plan_visitor::LogicalCardinalityExt;
46use crate::optimizer::property::{Distribution, Order, RequiredDist};
47use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
48
/// Logical plan node for a join of two inputs.
///
/// The join itself (inputs, join type, `on` condition, output indices) lives in the shared
/// [`generic::Join`] core; `base` is derived from that core in `with_core`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan base built from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The join description: left/right inputs, join type, `on` condition and output indices.
    core: generic::Join<PlanRef>,
}
60
61impl Distill for LogicalJoin {
62 fn distill<'a>(&self) -> XmlNode<'a> {
63 let verbose = self.base.ctx().is_explain_verbose();
64 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
65 vec.push(("type", Pretty::debug(&self.join_type())));
66
67 let concat_schema = self.core.concat_schema();
68 let cond = Pretty::debug(&ConditionDisplay {
69 condition: self.on(),
70 input_schema: &concat_schema,
71 });
72 vec.push(("on", cond));
73
74 if verbose {
75 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
76 vec.push(("output", data));
77 }
78
79 childless_record("LogicalJoin", vec)
80 }
81}
82
83impl LogicalJoin {
84 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
85 let core = generic::Join::with_full_output(left, right, join_type, on);
86 Self::with_core(core)
87 }
88
89 pub(crate) fn with_output_indices(
90 left: PlanRef,
91 right: PlanRef,
92 join_type: JoinType,
93 on: Condition,
94 output_indices: Vec<usize>,
95 ) -> Self {
96 let core = generic::Join::new(left, right, on, join_type, output_indices);
97 Self::with_core(core)
98 }
99
100 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
101 let base = PlanBase::new_logical_with_core(&core);
102 LogicalJoin { base, core }
103 }
104
105 pub fn create(
106 left: PlanRef,
107 right: PlanRef,
108 join_type: JoinType,
109 on_clause: ExprImpl,
110 ) -> PlanRef {
111 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
112 }
113
/// Number of internal columns of the join, as reported by the core's `internal_column_num`.
pub fn internal_column_num(&self) -> usize {
    self.core.internal_column_num()
}
117
/// Forwards to the core's `i2l_col_mapping_ignore_join_type`.
///
/// NOTE(review): by its name this maps internal columns to left-input columns regardless of the
/// join type — confirm against `generic::Join`.
pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
    self.core.i2l_col_mapping_ignore_join_type()
}
121
/// Forwards to the core's `i2r_col_mapping_ignore_join_type`.
///
/// NOTE(review): by its name this maps internal columns to right-input columns regardless of the
/// join type — confirm against `generic::Join`.
pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
    self.core.i2r_col_mapping_ignore_join_type()
}
125
/// Returns a reference to the join's `on` condition.
pub fn on(&self) -> &Condition {
    &self.core.on
}
130
/// Returns a reference to the underlying `generic::Join` core.
pub fn core(&self) -> &generic::Join<PlanRef> {
    &self.core
}
134
135 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
137 let input_refs = self
138 .core
139 .on
140 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
141 let index_group = input_refs
142 .ones()
143 .chunk_by(|i| *i < self.core.left.schema().len());
144 let left_index = index_group
145 .into_iter()
146 .next()
147 .map_or(vec![], |group| group.1.collect_vec());
148 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
149 group
150 .1
151 .map(|i| i - self.core.left.schema().len())
152 .collect_vec()
153 });
154 (left_index, right_index)
155 }
156
/// Returns the join type stored in the core.
pub fn join_type(&self) -> JoinType {
    self.core.join_type
}
161
/// Forwards to the core's `eq_indexes`.
///
/// NOTE(review): presumably the `(left, right)` column index pairs of the equality conditions —
/// confirm against `generic::Join::eq_indexes`.
pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
    self.core.eq_indexes()
}
166
/// Returns the output indices stored in the core.
///
/// `prune_col` indexes this list by output column position to obtain the corresponding
/// internal column.
pub fn output_indices(&self) -> &Vec<usize> {
    &self.core.output_indices
}
171
172 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
174 Self::with_core(generic::Join {
175 output_indices,
176 ..self.core.clone()
177 })
178 }
179
180 pub fn clone_with_cond(&self, on: Condition) -> Self {
182 Self::with_core(generic::Join {
183 on,
184 ..self.core.clone()
185 })
186 }
187
188 pub fn is_left_join(&self) -> bool {
189 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
190 }
191
192 pub fn is_right_join(&self) -> bool {
193 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
194 }
195
/// Forwards to the core's `is_full_out`.
pub fn is_full_out(&self) -> bool {
    self.core.is_full_out()
}
199
200 pub fn is_asof_join(&self) -> bool {
201 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
202 }
203
204 pub fn output_indices_are_trivial(&self) -> bool {
205 self.output_indices() == &(0..self.internal_column_num()).collect_vec()
206 }
207
208 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
213 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
214 JoinType::LeftOuter => (false, true),
215 JoinType::RightOuter => (true, false),
216 JoinType::FullOuter => (true, true),
217 _ => return join_type,
218 };
219
220 for expr in &predicate.conjunctions {
221 if let ExprImpl::FunctionCall(func) = expr {
222 match func.func_type() {
223 ExprType::Equal
224 | ExprType::NotEqual
225 | ExprType::LessThan
226 | ExprType::LessThanOrEqual
227 | ExprType::GreaterThan
228 | ExprType::GreaterThanOrEqual => {
229 for input in func.inputs() {
230 if let ExprImpl::InputRef(input) = input {
231 let idx = input.index;
232 if idx < left_col_num {
233 gen_null_in_left = false;
234 } else {
235 gen_null_in_right = false;
236 }
237 }
238 }
239 }
240 _ => {}
241 };
242 }
243 }
244
245 match (gen_null_in_left, gen_null_in_right) {
246 (true, true) => JoinType::FullOuter,
247 (true, false) => JoinType::RightOuter,
248 (false, true) => JoinType::LeftOuter,
249 (false, false) => JoinType::Inner,
250 }
251 }
252
253 fn to_batch_lookup_join_with_index_selection(
257 &self,
258 predicate: EqJoinPredicate,
259 logical_join: generic::Join<PlanRef>,
260 ) -> Result<Option<BatchLookupJoin>> {
261 match logical_join.join_type {
262 JoinType::Inner
263 | JoinType::LeftOuter
264 | JoinType::LeftSemi
265 | JoinType::LeftAnti
266 | JoinType::AsofInner
267 | JoinType::AsofLeftOuter => {}
268 _ => return Ok(None),
269 };
270
271 let right = self.right();
273 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
275 logical_scan
276 } else {
277 return Ok(None);
278 };
279
280 let mut result_plan = None;
281 if let Some(lookup_join) =
283 self.to_batch_lookup_join(predicate.clone(), logical_join.clone())?
284 {
285 result_plan = Some(lookup_join);
286 }
287
288 let indexes = logical_scan.indexes();
289 for index in indexes {
290 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
291 let index_scan: PlanRef = index_scan.into();
292 let that = self.clone_with_left_right(self.left(), index_scan.clone());
293 let mut new_logical_join = logical_join.clone();
294 new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");
295
296 if let Some(lookup_join) =
298 that.to_batch_lookup_join(predicate.clone(), new_logical_join)?
299 {
300 match &result_plan {
301 None => result_plan = Some(lookup_join),
302 Some(prev_lookup_join) => {
303 if prev_lookup_join.lookup_prefix_len()
305 < lookup_join.lookup_prefix_len()
306 {
307 result_plan = Some(lookup_join)
308 }
309 }
310 }
311 }
312 }
313 }
314
315 Ok(result_plan)
316 }
317
/// Tries to plan this join as a `BatchLookupJoin` against the table scanned on the right side.
///
/// The scan is taken from `self.right()`; `logical_join` supplies the left input, the join type
/// and the output indices of the resulting node, and `predicate` is the equi-join predicate.
///
/// Returns `Ok(None)` when a lookup join is not applicable:
/// - the join type is not inner / left outer / left semi / left anti / as-of inner / as-of left
///   outer;
/// - the right input is not a `LogicalScan`;
/// - the table's distribution key is empty;
/// - the right-side equality columns do not cover a long enough prefix of the table's order key;
/// - no equality condition remains once the scan's own predicate is merged into the join
///   condition.
///
/// # Errors
/// For as-of joins, propagates the error of `get_inequality_desc_from_predicate`.
///
/// # Panics
/// Panics if a distribution key column is not part of the order key, or if a projection
/// expression pulled up from the scan is not an input reference.
fn to_batch_lookup_join(
    &self,
    predicate: EqJoinPredicate,
    logical_join: generic::Join<PlanRef>,
) -> Result<Option<BatchLookupJoin>> {
    match logical_join.join_type {
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {}
        _ => return Ok(None),
    };

    let right = self.right();
    let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
        logical_scan
    } else {
        return Ok(None);
    };
    let table_desc = logical_scan.table_desc().clone();
    let output_column_ids = logical_scan.output_column_ids();

    let order_col_ids = table_desc.order_column_ids();
    let order_key = table_desc.order_column_indices();
    let dist_key = table_desc.distribution_key.clone();
    // Position of every distribution key column inside the order key.
    let mut dist_key_in_order_key_pos = vec![];
    for d in dist_key {
        let pos = order_key
            .iter()
            .position(|&x| x == d)
            .expect("dist_key must in order_key");
        dist_key_in_order_key_pos.push(pos);
    }
    // Length of the shortest order-key prefix that contains the whole distribution key
    // (0 when the distribution key is empty).
    let shortest_prefix_len = dist_key_in_order_key_pos
        .iter()
        .max()
        .map_or(0, |pos| pos + 1);

    // An empty distribution key is not supported by this lookup join.
    if shortest_prefix_len == 0 {
        return Ok(None);
    }

    // Walk the order key from the front and, for each order column, find the equality key whose
    // right-side column is that order column. Stop at the first order column that has none.
    // `reorder_idx[k]` is the position (within the predicate's equality keys) matching the k-th
    // order column.
    let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
    for order_col_id in order_col_ids {
        let mut found = false;
        for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
            if order_col_id == output_column_ids[eq_idx] {
                reorder_idx.push(i);
                found = true;
                break;
            }
        }
        if !found {
            break;
        }
    }
    // The matched prefix must at least cover the distribution key.
    if reorder_idx.len() < shortest_prefix_len {
        return Ok(None);
    }
    let lookup_prefix_len = reorder_idx.len();
    let predicate = predicate.reorder(&reorder_idx);

    // Pull the scan's predicate (and projection, if any) up out of the scan.
    let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
    // `o2r[i]`: index in the new scan's output that the old scan's i-th output column refers to.
    // Without a projection the mapping is the identity.
    let o2r = if let Some(project_expr) = project_expr {
        project_expr
            .into_iter()
            .map(|x| x.as_input_ref().unwrap().index)
            .collect_vec()
    } else {
        (0..logical_scan.output_col_idx().len()).collect_vec()
    };
    let left_schema_len = logical_join.left.schema().len();

    // Rewrites right-side column references of the join predicate through `o2r`; left-side
    // references (index < offset) are kept as they are.
    let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
        offset: left_schema_len,
        mapping: o2r.clone(),
    };

    let new_eq_cond = predicate
        .eq_cond()
        .rewrite_expr(&mut join_predicate_rewriter);

    // Shifts every column reference of the pulled-up scan predicate by the left schema length so
    // that it addresses the right side of the join.
    let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
        offset: left_schema_len,
    };

    let new_other_cond = predicate
        .other_cond()
        .clone()
        .rewrite_expr(&mut join_predicate_rewriter)
        .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

    let new_join_on = new_eq_cond.and(new_other_cond);
    let new_predicate = EqJoinPredicate::create(
        left_schema_len,
        new_scan.schema().len(),
        new_join_on.clone(),
    );

    // After merging the scan predicate the condition may no longer contain an equality key, in
    // which case a lookup join cannot be used.
    if !new_predicate.has_eq() {
        return Ok(None);
    }

    // Output indices that pointed at the old scan's columns are remapped through `o2r`.
    let new_join_output_indices = logical_join
        .output_indices
        .iter()
        .map(|&x| {
            if x < left_schema_len {
                x
            } else {
                o2r[x - left_schema_len] + left_schema_len
            }
        })
        .collect_vec();

    let new_scan_output_column_ids = new_scan.output_column_ids();
    let as_of = new_scan.as_of.clone();

    let new_logical_join = generic::Join::new(
        logical_join.left,
        new_scan.into(),
        new_join_on,
        logical_join.join_type,
        new_join_output_indices,
    );

    // Only as-of joins carry an inequality description; it is derived from the non-equality part
    // of the reordered (pre-rewrite) predicate.
    let asof_desc = self
        .is_asof_join()
        .then(|| {
            Self::get_inequality_desc_from_predicate(
                predicate.other_cond().clone(),
                left_schema_len,
            )
        })
        .transpose()?;

    // NOTE(review): the meaning of the literal `false` argument is defined by
    // `BatchLookupJoin::new`, which is not visible here.
    Ok(Some(BatchLookupJoin::new(
        new_logical_join,
        new_predicate,
        table_desc,
        new_scan_output_column_ids,
        lookup_prefix_len,
        false,
        as_of,
        asof_desc,
    )))
}
482
/// Consumes the node and returns the parts of its core, as produced by the core's `decompose`:
/// two plan inputs, the `on` condition, the join type and the output indices.
///
/// NOTE(review): presumably the two `PlanRef`s are (left, right) in that order — confirm against
/// `generic::Join::decompose`.
pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
    self.core.decompose()
}
486}
487
488impl PlanTreeNodeBinary for LogicalJoin {
489 fn left(&self) -> PlanRef {
490 self.core.left.clone()
491 }
492
493 fn right(&self) -> PlanRef {
494 self.core.right.clone()
495 }
496
497 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
498 Self::with_core(generic::Join {
499 left,
500 right,
501 ..self.core.clone()
502 })
503 }
504
505 fn rewrite_with_left_right(
506 &self,
507 left: PlanRef,
508 left_col_change: ColIndexMapping,
509 right: PlanRef,
510 right_col_change: ColIndexMapping,
511 ) -> (Self, ColIndexMapping) {
512 let (new_on, new_output_indices) = {
513 let (mut map, _) = left_col_change.clone().into_parts();
514 let (mut right_map, _) = right_col_change.clone().into_parts();
515 for i in right_map.iter_mut().flatten() {
516 *i += left.schema().len();
517 }
518 map.append(&mut right_map);
519 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
520
521 let new_output_indices = self
522 .output_indices()
523 .iter()
524 .map(|&i| mapping.map(i))
525 .collect::<Vec<_>>();
526 let new_on = self.on().clone().rewrite_expr(&mut mapping);
527 (new_on, new_output_indices)
528 };
529
530 let join = Self::with_output_indices(
531 left,
532 right,
533 self.join_type(),
534 new_on,
535 new_output_indices.clone(),
536 );
537
538 let new_i2o = ColIndexMapping::with_remaining_columns(
539 &new_output_indices,
540 join.internal_column_num(),
541 );
542
543 let old_o2i = self.core.o2i_col_mapping();
544
545 let old_o2l = old_o2i
546 .composite(&self.core.i2l_col_mapping())
547 .composite(&left_col_change);
548 let old_o2r = old_o2i
549 .composite(&self.core.i2r_col_mapping())
550 .composite(&right_col_change);
551 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
552 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
553
554 let out_col_change = old_o2l
555 .composite(&new_l2o)
556 .union(&old_o2r.composite(&new_r2o));
557 (join, out_col_change)
558 }
559}
560
// Invokes the project macro `impl_plan_tree_node_for_binary!` for `LogicalJoin`.
// NOTE(review): presumably derives the generic plan-tree-node impl from the
// `PlanTreeNodeBinary` impl — confirm against the macro definition.
impl_plan_tree_node_for_binary! { LogicalJoin }
562
563impl ColPrunable for LogicalJoin {
564 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
565 let required_cols = required_cols
567 .iter()
568 .map(|i| self.output_indices()[*i])
569 .collect_vec();
570 let left_len = self.left().schema().fields.len();
571
572 let total_len = self.left().schema().len() + self.right().schema().len();
573 let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
574
575 required_cols.iter().for_each(|&i| {
576 if self.is_right_join() {
577 resized_required_cols.insert(left_len + i);
578 } else {
579 resized_required_cols.insert(i);
580 }
581 });
582
583 let mut visitor = CollectInputRef::new(resized_required_cols);
586 self.on().visit_expr(&mut visitor);
587 let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
588
589 let mut left_required_cols = Vec::new();
590 let mut right_required_cols = Vec::new();
591 left_right_required_cols.iter().for_each(|&i| {
592 if i < left_len {
593 left_required_cols.push(i);
594 } else {
595 right_required_cols.push(i - left_len);
596 }
597 });
598
599 let mut on = self.on().clone();
600 let mut mapping =
601 ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
602 on = on.rewrite_expr(&mut mapping);
603
604 let new_output_indices = {
605 let required_inputs_in_output = if self.is_left_join() {
606 &left_required_cols
607 } else if self.is_right_join() {
608 &right_required_cols
609 } else {
610 &left_right_required_cols
611 };
612
613 let mapping =
614 ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
615 required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
616 };
617
618 LogicalJoin::with_output_indices(
619 self.left().prune_col(&left_required_cols, ctx),
620 self.right().prune_col(&right_required_cols, ctx),
621 self.join_type(),
622 on,
623 new_output_indices,
624 )
625 .into()
626 }
627}
628
629impl ExprRewritable for LogicalJoin {
630 fn has_rewritable_expr(&self) -> bool {
631 true
632 }
633
634 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
635 let mut core = self.core.clone();
636 core.rewrite_exprs(r);
637 Self {
638 base: self.base.clone_with_new_plan_id(),
639 core,
640 }
641 .into()
642 }
643}
644
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions of the join by forwarding the visitor to the core.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
650
651fn derive_predicate_from_eq_condition(
669 expr: &ExprImpl,
670 eq_condition: &EqJoinPredicate,
671 col_num: usize,
672 expr_is_left: bool,
673) -> Option<ExprImpl> {
674 if expr.is_impure() {
675 return None;
676 }
677 let eq_indices = eq_condition
678 .eq_indexes_typed()
679 .iter()
680 .filter_map(|(l, r)| {
681 if l.return_type() != r.return_type() {
682 None
683 } else if expr_is_left {
684 Some(l.index())
685 } else {
686 Some(r.index())
687 }
688 })
689 .collect_vec();
690 if expr
691 .collect_input_refs(col_num)
692 .ones()
693 .any(|index| !eq_indices.contains(&index))
694 {
695 return None;
697 }
698 let other_side_mapping = if expr_is_left {
701 eq_condition.eq_indexes_typed().into_iter().collect()
702 } else {
703 eq_condition
704 .eq_indexes_typed()
705 .into_iter()
706 .map(|(x, y)| (y, x))
707 .collect()
708 };
709 struct InputRefsRewriter {
710 mapping: HashMap<InputRef, InputRef>,
711 }
712 impl ExprRewriter for InputRefsRewriter {
713 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
714 self.mapping[&input_ref].clone().into()
715 }
716 }
717 Some(
718 InputRefsRewriter {
719 mapping: other_side_mapping,
720 }
721 .rewrite_expr(expr.clone()),
722 )
723}
724
725struct LookupJoinPredicateRewriter {
727 offset: usize,
728 mapping: Vec<usize>,
729}
730impl ExprRewriter for LookupJoinPredicateRewriter {
731 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
732 if input_ref.index() < self.offset {
733 input_ref.into()
734 } else {
735 InputRef::new(
736 self.mapping[input_ref.index() - self.offset] + self.offset,
737 input_ref.return_type(),
738 )
739 .into()
740 }
741 }
742}
743
744struct LookupJoinScanPredicateRewriter {
746 offset: usize,
747}
748impl ExprRewriter for LookupJoinScanPredicateRewriter {
749 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
750 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
751 }
752}
753
impl PredicatePushdown for LogicalJoin {
    /// Pushes `predicate` (expressed over this join's output columns) into the join and its
    /// inputs, and returns the rebuilt plan.
    ///
    /// Steps, in order:
    /// 1. Rewrite `predicate` from output columns to the join's internal columns.
    /// 2. Re-derive the join type with `simplify_outer`, using that predicate.
    /// 3. `push_down_into_join` splits the predicate into a part for the left input, a part for
    ///    the right input and a part for the `on` condition. It takes `predicate` by `&mut`;
    ///    presumably whatever cannot be pushed stays in `predicate`.
    /// 4. `push_down_join_condition` does the same for the combined `on` condition.
    /// 5. For suitable join types, single-side predicates are mirrored to the other side through
    ///    the equality keys (`derive_predicate_from_eq_condition`).
    /// 6. Both inputs are pushed down recursively, the join is rebuilt with the derived join
    ///    type, and the leftover predicate (mapped back to output columns) is put in a
    ///    `LogicalFilter` on top.
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        // Rewrite the predicate from output columns to internal columns.
        let mut predicate = {
            let mut mapping = self.core.o2i_col_mapping();
            predicate.rewrite_expr(&mut mapping)
        };

        let left_col_num = self.left().schema().len();
        let right_col_num = self.right().schema().len();
        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

        // The flag is turned off when the right side is a process-time temporal scan.
        let push_down_temporal_predicate = !self.should_be_temporal_join();

        let (left_from_filter, right_from_filter, on) = push_down_into_join(
            &mut predicate,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        // The part of the filter that landed in `on` joins the existing condition before the
        // condition itself is split.
        let mut new_on = self.on().clone().and(on);
        let (left_from_on, right_from_on) = push_down_join_condition(
            &mut new_on,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        let left_predicate = left_from_filter.and(left_from_on);
        let right_predicate = right_from_filter.and(right_from_on);

        // Equality keys of the remaining `on` condition, used to mirror predicates across sides.
        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());

        // Left-side predicates are mirrored to the right only for these join types.
        let right_from_left = if matches!(
            join_type,
            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
        ) {
            Condition {
                conjunctions: left_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Right-side predicates are mirrored to the left only for these join types.
        let left_from_right = if matches!(
            join_type,
            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
        ) {
            Condition {
                conjunctions: right_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(
                            expr,
                            &eq_condition,
                            right_col_num,
                            false,
                        )
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Both derived sets were computed from the pre-mirroring predicates above, so derived
        // predicates are not mirrored back.
        let left_predicate = left_predicate.and(left_from_right);
        let right_predicate = right_predicate.and(right_from_left);

        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
        let new_join = LogicalJoin::with_output_indices(
            new_left,
            new_right,
            join_type,
            new_on,
            self.output_indices().clone(),
        );

        // Map the leftover predicate back to output columns and keep it as a filter on top.
        let mut mapping = self.core.i2o_col_mapping();
        predicate = predicate.rewrite_expr(&mut mapping);
        LogicalFilter::create(new_join.into(), predicate)
    }
}
877
878impl LogicalJoin {
/// Converts both inputs to stream plans with distributions that agree on the join keys.
///
/// The right input is converted first, required to be sharded by its equality columns. What
/// happens to the left input depends on the distribution the right input actually ends up with:
/// - `HashShard`: the left input is required to have the right distribution translated through
///   the right-to-left equality column mapping.
/// - `UpstreamHashShard`: the left input is converted sharded by its own equality columns. If it
///   ends up `HashShard`, the right input is made to follow it; if it is also
///   `UpstreamHashShard`, both sides are enforced to be hash-sharded by their equality columns.
///
/// Returns `(left, right)` as stream plans.
///
/// # Panics
/// Hits `unreachable!()` if either side ends up with any other distribution.
fn get_stream_input_for_hash_join(
    &self,
    predicate: &EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<(PlanRef, PlanRef)> {
    use super::stream::prelude::*;

    let mut right = self.right().to_stream_with_dist_required(
        &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
        ctx,
    )?;
    // Still the logical left input at this point; converted below.
    let mut left = self.left();

    // Equality-column mappings between the two sides.
    let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
    let l2r = predicate.l2r_eq_columns_mapping(left.schema().len(), right.schema().len());

    let right_dist = right.distribution();
    match right_dist {
        Distribution::HashShard(_) => {
            // Ask the left side for the right side's distribution, expressed in left columns.
            let left_dist = r2l
                .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
            left = left.to_stream_with_dist_required(&left_dist, ctx)?;
        }
        Distribution::UpstreamHashShard(_, _) => {
            left = left.to_stream_with_dist_required(
                &RequiredDist::shard_by_key(
                    self.left().schema().len(),
                    &predicate.left_eq_indexes(),
                ),
                ctx,
            )?;
            let left_dist = left.distribution();
            match left_dist {
                Distribution::HashShard(_) => {
                    // Make the right side follow the left side's distribution.
                    let right_dist = l2r.rewrite_required_distribution(
                        &RequiredDist::PhysicalDist(left_dist.clone()),
                    );
                    right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
                }
                Distribution::UpstreamHashShard(_, _) => {
                    // Both sides are upstream-sharded: enforce hash sharding on the join keys
                    // for both.
                    left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                        .enforce_if_not_satisfies(left, &Order::any())?;
                    right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                        .enforce_if_not_satisfies(right, &Order::any())?;
                }
                _ => unreachable!(),
            }
        }
        _ => unreachable!(),
    }
    Ok((left, right))
}
931
932 fn to_stream_hash_join(
933 &self,
934 predicate: EqJoinPredicate,
935 ctx: &mut ToStreamContext,
936 ) -> Result<PlanRef> {
937 use super::stream::prelude::*;
938
939 assert!(predicate.has_eq());
940 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
941
942 let logical_join = self.clone_with_left_right(left, right);
943
944 let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
953
954 let force_filter_inside_join = self
955 .base
956 .ctx()
957 .session_ctx()
958 .config()
959 .streaming_force_filter_inside_join();
960
961 let pull_filter = self.join_type() == JoinType::Inner
962 && stream_hash_join.eq_join_predicate().has_non_eq()
963 && stream_hash_join.inequality_pairs().is_empty()
964 && (!force_filter_inside_join);
965 if pull_filter {
966 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
967
968 let logical_join = logical_join.clone_with_output_indices(default_indices.clone());
970 let eq_cond = EqJoinPredicate::new(
971 Condition::true_cond(),
972 predicate.eq_keys().to_vec(),
973 self.left().schema().len(),
974 self.right().schema().len(),
975 );
976 let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
977 let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
978 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
979 let plan = StreamFilter::new(logical_filter).into();
980 if self.output_indices() != &default_indices {
981 let logical_project = generic::Project::with_mapping(
982 plan,
983 ColIndexMapping::with_remaining_columns(
984 self.output_indices(),
985 self.internal_column_num(),
986 ),
987 );
988 Ok(StreamProject::new(logical_project).into())
989 } else {
990 Ok(plan)
991 }
992 } else {
993 Ok(stream_hash_join.into())
994 }
995 }
996
997 fn should_be_temporal_join(&self) -> bool {
998 let right = self.right();
999 if let Some(logical_scan) = right.as_logical_scan() {
1000 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1001 } else {
1002 false
1003 }
1004 }
1005
1006 fn to_stream_temporal_join_with_index_selection(
1007 &self,
1008 predicate: EqJoinPredicate,
1009 ctx: &mut ToStreamContext,
1010 ) -> Result<PlanRef> {
1011 let right = self.right();
1013 let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1015
1016 let mut result_plan: Result<StreamTemporalJoin> =
1018 self.to_stream_temporal_join(predicate.clone(), ctx);
1019 if let Ok(temporal_join) = &result_plan
1021 && temporal_join.eq_join_predicate().eq_indexes().len()
1022 == logical_scan.primary_key().len()
1023 {
1024 return result_plan.map(|x| x.into());
1025 }
1026 let indexes = logical_scan.indexes();
1027 for index in indexes {
1028 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1030 let index_scan: PlanRef = index_scan.into();
1031 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1032 if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1033 match &result_plan {
1034 Err(_) => result_plan = Ok(temporal_join),
1035 Ok(prev_temporal_join) => {
1036 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1038 < temporal_join.eq_join_predicate().eq_indexes().len()
1039 {
1040 result_plan = Ok(temporal_join)
1041 }
1042 }
1043 }
1044 }
1045 }
1046 }
1047
1048 result_plan.map(|x| x.into())
1049 }
1050
1051 fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1052 let Some(logical_scan) = right.as_logical_scan() else {
1053 return Err(RwError::from(ErrorCode::NotSupported(
1054 "Temporal join requires a table scan as its lookup table".into(),
1055 "Please provide a table scan".into(),
1056 )));
1057 };
1058
1059 if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1060 return Err(RwError::from(ErrorCode::NotSupported(
1061 "Temporal join requires a table defined as temporal table".into(),
1062 "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1063 )));
1064 }
1065 Ok(logical_scan)
1066 }
1067
1068 fn temporal_join_scan_predicate_pull_up(
1069 logical_scan: &LogicalScan,
1070 predicate: EqJoinPredicate,
1071 output_indices: &[usize],
1072 left_schema_len: usize,
1073 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1074 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1076 let o2r = if let Some(project_expr) = project_expr {
1078 project_expr
1079 .into_iter()
1080 .map(|x| x.as_input_ref().unwrap().index)
1081 .collect_vec()
1082 } else {
1083 (0..logical_scan.output_col_idx().len()).collect_vec()
1084 };
1085 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1086 offset: left_schema_len,
1087 mapping: o2r.clone(),
1088 };
1089
1090 let new_eq_cond = predicate
1091 .eq_cond()
1092 .rewrite_expr(&mut join_predicate_rewriter);
1093
1094 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1095 offset: left_schema_len,
1096 };
1097
1098 let new_other_cond = predicate
1099 .other_cond()
1100 .clone()
1101 .rewrite_expr(&mut join_predicate_rewriter)
1102 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1103
1104 let new_join_on = new_eq_cond.and(new_other_cond);
1105
1106 let new_predicate = EqJoinPredicate::create(
1107 left_schema_len,
1108 new_scan.schema().len(),
1109 new_join_on.clone(),
1110 );
1111
1112 let new_join_output_indices = output_indices
1115 .iter()
1116 .map(|&x| {
1117 if x < left_schema_len {
1118 x
1119 } else {
1120 o2r[x - left_schema_len] + left_schema_len
1121 }
1122 })
1123 .collect_vec();
1124 let new_stream_table_scan =
1126 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1127 Ok((
1128 new_stream_table_scan,
1129 new_predicate,
1130 new_join_on,
1131 new_join_output_indices,
1132 ))
1133 }
1134
/// Plans this join as a `StreamTemporalJoin` that looks up the right-side table by a prefix of
/// its order key.
///
/// # Errors
/// Returns `ErrorCode::NotSupported` when
/// - the right input is not a process-time temporal table scan (`check_temporal_rhs`);
/// - the right-side equality columns do not cover a long enough prefix of the table's order key;
/// - no equality condition remains once the scan's own predicate is merged into the join
///   condition.
///
/// Errors from converting the left input to a stream plan are propagated.
///
/// # Panics
/// Panics if `predicate` has no equality condition, or if a distribution key column is not part
/// of the order key.
fn to_stream_temporal_join(
    &self,
    predicate: EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<StreamTemporalJoin> {
    use super::stream::prelude::*;

    assert!(predicate.has_eq());

    let right = self.right();

    let logical_scan = Self::check_temporal_rhs(&right)?;

    let table_desc = logical_scan.table_desc();
    let output_column_ids = logical_scan.output_column_ids();

    let order_col_ids = table_desc.order_column_ids();
    let order_key = table_desc.order_column_indices();
    let dist_key = table_desc.distribution_key.clone();

    // Position of every distribution key column inside the order key.
    let mut dist_key_in_order_key_pos = vec![];
    for d in dist_key {
        let pos = order_key
            .iter()
            .position(|&x| x == d)
            .expect("dist_key must in order_key");
        dist_key_in_order_key_pos.push(pos);
    }
    // Length of the shortest order-key prefix that contains the whole distribution key
    // (0 when the distribution key is empty).
    let shortest_prefix_len = dist_key_in_order_key_pos
        .iter()
        .max()
        .map_or(0, |pos| pos + 1);

    // Walk the order key from the front and, for each order column, find the equality key whose
    // right-side column is that order column. Stop at the first order column that has none.
    let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
    for order_col_id in order_col_ids {
        let mut found = false;
        for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
            if order_col_id == output_column_ids[eq_idx] {
                reorder_idx.push(i);
                found = true;
                break;
            }
        }
        if !found {
            break;
        }
    }
    // The matched prefix must at least cover the distribution key.
    if reorder_idx.len() < shortest_prefix_len {
        return Err(RwError::from(ErrorCode::NotSupported(
            "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
            "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
        )));
    }
    let lookup_prefix_len = reorder_idx.len();
    let predicate = predicate.reorder(&reorder_idx);

    // The left side is distributed by the left columns paired with the table's distribution key
    // columns; with an empty distribution key a single distribution is required instead.
    let required_dist = if dist_key_in_order_key_pos.is_empty() {
        RequiredDist::single()
    } else {
        let left_eq_indexes = predicate.left_eq_indexes();
        let left_dist_key = dist_key_in_order_key_pos
            .iter()
            .map(|pos| left_eq_indexes[*pos])
            .collect_vec();

        RequiredDist::hash_shard(&left_dist_key)
    };

    let left = self.left().to_stream(ctx)?;
    // The distribution is always enforced on the left input (`enforce`, not
    // `enforce_if_not_satisfies`).
    let left = required_dist.enforce(left, &Order::any());

    let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
        Self::temporal_join_scan_predicate_pull_up(
            logical_scan,
            predicate,
            self.output_indices(),
            self.left().schema().len(),
        )?;

    // The lookup side is attached through `RequiredDist::no_shuffle`.
    let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
    if !new_predicate.has_eq() {
        return Err(RwError::from(ErrorCode::NotSupported(
            "Temporal join requires a non trivial join condition".into(),
            "Please remove the false condition of the join".into(),
        )));
    }

    let new_logical_join = generic::Join::new(
        left,
        right,
        new_join_on,
        self.join_type(),
        new_join_output_indices,
    );

    // Only the equality keys that form the lookup prefix are retained in the predicate.
    let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);

    // The nested-loop variant passes `true` as the last argument; this one passes `false`.
    Ok(StreamTemporalJoin::new(
        new_logical_join,
        new_predicate,
        false,
    ))
}
1245
1246 fn to_stream_nested_loop_temporal_join(
1247 &self,
1248 predicate: EqJoinPredicate,
1249 ctx: &mut ToStreamContext,
1250 ) -> Result<PlanRef> {
1251 use super::stream::prelude::*;
1252 assert!(!predicate.has_eq());
1253
1254 let left = self.left().to_stream_with_dist_required(
1255 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1256 ctx,
1257 )?;
1258 assert!(left.as_stream_exchange().is_some());
1259
1260 if self.join_type() != JoinType::Inner {
1261 return Err(RwError::from(ErrorCode::NotSupported(
1262 "Temporal join requires an inner join".into(),
1263 "Please use an inner join".into(),
1264 )));
1265 }
1266
1267 if !left.append_only() {
1268 return Err(RwError::from(ErrorCode::NotSupported(
1269 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1270 "Please ensure the left hash side is append only".into(),
1271 )));
1272 }
1273
1274 let right = self.right();
1275 let logical_scan = Self::check_temporal_rhs(&right)?;
1276
1277 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1278 Self::temporal_join_scan_predicate_pull_up(
1279 logical_scan,
1280 predicate,
1281 self.output_indices(),
1282 self.left().schema().len(),
1283 )?;
1284
1285 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1286
1287 let new_logical_join = generic::Join::new(
1289 left,
1290 right,
1291 new_join_on,
1292 self.join_type(),
1293 new_join_output_indices,
1294 );
1295
1296 Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true).into())
1297 }
1298
1299 fn to_stream_dynamic_filter(
1300 &self,
1301 predicate: Condition,
1302 ctx: &mut ToStreamContext,
1303 ) -> Result<Option<PlanRef>> {
1304 use super::stream::prelude::*;
1305
1306 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1312 return Ok(None);
1313 }
1314
1315 if !self.right().max_one_row() {
1317 return Ok(None);
1318 }
1319 if self.right().schema().len() != 1 {
1320 return Ok(None);
1321 }
1322
1323 if predicate.conjunctions.len() > 1 {
1325 return Ok(None);
1326 }
1327 let expr: ExprImpl = predicate.into();
1328 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1329 Some(v) => v,
1330 None => return Ok(None),
1331 };
1332
1333 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1334 && right_ref.index == self.left().schema().len() ;
1335 if !condition_cross_inputs {
1336 return Ok(None);
1338 }
1339
1340 if self.left().schema().fields()[left_ref.index].data_type
1342 != self.right().schema().fields()[0].data_type
1343 {
1344 return Ok(None);
1345 }
1346
1347 let all_output_from_left = self
1349 .output_indices()
1350 .iter()
1351 .all(|i| *i < self.left().schema().len());
1352 if !all_output_from_left {
1353 return Ok(None);
1354 }
1355
1356 let left = self.left().to_stream(ctx)?;
1357 let right = self.right().to_stream_with_dist_required(
1358 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1359 ctx,
1360 )?;
1361
1362 assert!(right.as_stream_exchange().is_some());
1363 assert_eq!(
1364 *right.inputs().iter().exactly_one().unwrap().distribution(),
1365 Distribution::Single
1366 );
1367
1368 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1369 let plan = StreamDynamicFilter::new(core).into();
1370 if self
1372 .output_indices()
1373 .iter()
1374 .copied()
1375 .ne(0..self.left().schema().len())
1376 {
1377 let logical_project = generic::Project::with_mapping(
1380 plan,
1381 ColIndexMapping::with_remaining_columns(
1382 self.output_indices(),
1383 self.left().schema().len(),
1384 ),
1385 );
1386 Ok(Some(StreamProject::new(logical_project).into()))
1387 } else {
1388 Ok(Some(plan))
1389 }
1390 }
1391
1392 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
1393 let predicate = EqJoinPredicate::create(
1394 self.left().schema().len(),
1395 self.right().schema().len(),
1396 self.on().clone(),
1397 );
1398 assert!(predicate.has_eq());
1399
1400 let mut logical_join = self.core.clone();
1401 logical_join.left = logical_join.left.to_batch()?;
1402 logical_join.right = logical_join.right.to_batch()?;
1403
1404 Ok(self
1405 .to_batch_lookup_join(predicate, logical_join)?
1406 .expect("Fail to convert to lookup join")
1407 .into())
1408 }
1409
1410 fn to_stream_asof_join(
1411 &self,
1412 predicate: EqJoinPredicate,
1413 ctx: &mut ToStreamContext,
1414 ) -> Result<PlanRef> {
1415 use super::stream::prelude::*;
1416
1417 if predicate.eq_keys().is_empty() {
1418 return Err(ErrorCode::InvalidInputSyntax(
1419 "AsOf join requires at least 1 equal condition".to_owned(),
1420 )
1421 .into());
1422 }
1423
1424 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1425 let left_len = left.schema().len();
1426 let logical_join = self.clone_with_left_right(left, right);
1427
1428 let inequality_desc =
1429 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1430
1431 Ok(StreamAsOfJoin::new(logical_join.core.clone(), predicate, inequality_desc).into())
1432 }
1433
1434 fn to_batch_hash_join(
1436 &self,
1437 logical_join: generic::Join<PlanRef>,
1438 predicate: EqJoinPredicate,
1439 ) -> Result<PlanRef> {
1440 use super::batch::prelude::*;
1441
1442 let left_schema_len = logical_join.left.schema().len();
1443 let asof_desc = self
1444 .is_asof_join()
1445 .then(|| {
1446 Self::get_inequality_desc_from_predicate(
1447 predicate.other_cond().clone(),
1448 left_schema_len,
1449 )
1450 })
1451 .transpose()?;
1452
1453 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1454 Ok(batch_join.into())
1455 }
1456
1457 pub fn get_inequality_desc_from_predicate(
1458 predicate: Condition,
1459 left_input_len: usize,
1460 ) -> Result<AsOfJoinDesc> {
1461 let expr: ExprImpl = predicate.into();
1462 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1463 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1464 {
1465 Ok(AsOfJoinDesc {
1466 left_idx: left_input_ref.index() as u32,
1467 right_idx: (right_input_ref.index() - left_input_len) as u32,
1468 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1469 })
1470 } else {
1471 bail!("inequal condition from the same side should be push down in optimizer");
1472 }
1473 } else {
1474 Err(ErrorCode::InvalidInputSyntax(
1475 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1476 )
1477 .into())
1478 }
1479 }
1480
1481 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1482 match expr_type {
1483 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1484 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1485 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1486 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1487 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1488 "Invalid comparison type: {}",
1489 expr_type.as_str_name()
1490 ))
1491 .into()),
1492 }
1493 }
1494}
1495
1496impl ToBatch for LogicalJoin {
1497 fn to_batch(&self) -> Result<PlanRef> {
1498 let predicate = EqJoinPredicate::create(
1499 self.left().schema().len(),
1500 self.right().schema().len(),
1501 self.on().clone(),
1502 );
1503
1504 let mut logical_join = self.core.clone();
1505 logical_join.left = logical_join.left.to_batch()?;
1506 logical_join.right = logical_join.right.to_batch()?;
1507
1508 let ctx = self.base.ctx();
1509 let config = ctx.session_ctx().config();
1510
1511 if predicate.has_eq() {
1512 if !predicate.eq_keys_are_type_aligned() {
1513 return Err(ErrorCode::InternalError(format!(
1514 "Join eq keys are not aligned for predicate: {predicate:?}"
1515 ))
1516 .into());
1517 }
1518 if config.batch_enable_lookup_join()
1519 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1520 predicate.clone(),
1521 logical_join.clone(),
1522 )?
1523 {
1524 return Ok(lookup_join.into());
1525 }
1526 self.to_batch_hash_join(logical_join, predicate)
1527 } else if self.is_asof_join() {
1528 Err(ErrorCode::InvalidInputSyntax(
1529 "AsOf join requires at least 1 equal condition".to_owned(),
1530 )
1531 .into())
1532 } else {
1533 Ok(BatchNestedLoopJoin::new(logical_join).into())
1535 }
1536 }
1537}
1538
1539impl ToStream for LogicalJoin {
1540 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1541 if self
1542 .on()
1543 .conjunctions
1544 .iter()
1545 .any(|cond| cond.count_nows() > 0)
1546 {
1547 return Err(ErrorCode::NotSupported(
1548 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1549 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1550 }
1551
1552 let predicate = EqJoinPredicate::create(
1553 self.left().schema().len(),
1554 self.right().schema().len(),
1555 self.on().clone(),
1556 );
1557
1558 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1559 self.to_stream_asof_join(predicate, ctx)
1560 } else if predicate.has_eq() {
1561 if !predicate.eq_keys_are_type_aligned() {
1562 return Err(ErrorCode::InternalError(format!(
1563 "Join eq keys are not aligned for predicate: {predicate:?}"
1564 ))
1565 .into());
1566 }
1567
1568 if self.should_be_temporal_join() {
1569 self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1570 } else {
1571 self.to_stream_hash_join(predicate, ctx)
1572 }
1573 } else if self.should_be_temporal_join() {
1574 self.to_stream_nested_loop_temporal_join(predicate, ctx)
1575 } else if let Some(dynamic_filter) =
1576 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1577 {
1578 Ok(dynamic_filter)
1579 } else {
1580 Err(RwError::from(ErrorCode::NotSupported(
1581 "streaming nested-loop join".to_owned(),
1582 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1583 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1584 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1585 )))
1586 }
1587 }
1588
1589 fn logical_rewrite_for_stream(
1590 &self,
1591 ctx: &mut RewriteStreamContext,
1592 ) -> Result<(PlanRef, ColIndexMapping)> {
1593 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1594 let left_len = left.schema().len();
1595 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1596 let (join, out_col_change) = self.rewrite_with_left_right(
1597 left.clone(),
1598 left_col_change,
1599 right.clone(),
1600 right_col_change,
1601 );
1602
1603 let mapping = ColIndexMapping::with_remaining_columns(
1604 join.output_indices(),
1605 join.internal_column_num(),
1606 );
1607
1608 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1609 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1610
1611 let mut left_to_add = left
1613 .expect_stream_key()
1614 .iter()
1615 .cloned()
1616 .filter(|i| l2o.try_map(*i).is_none())
1617 .collect_vec();
1618
1619 let mut right_to_add = right
1620 .expect_stream_key()
1621 .iter()
1622 .filter(|&&i| r2o.try_map(i).is_none())
1623 .map(|&i| i + left_len)
1624 .collect_vec();
1625
1626 let right_len = right.schema().len();
1629 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1630
1631 let either_or_both = self.core.add_which_join_key_to_pk();
1632
1633 for (lk, rk) in eq_predicate.eq_indexes() {
1634 match either_or_both {
1635 EitherOrBoth::Left(_) => {
1636 if l2o.try_map(lk).is_none() {
1637 left_to_add.push(lk);
1638 }
1639 }
1640 EitherOrBoth::Right(_) => {
1641 if r2o.try_map(rk).is_none() {
1642 right_to_add.push(rk + left_len)
1643 }
1644 }
1645 EitherOrBoth::Both(_, _) => {
1646 if l2o.try_map(lk).is_none() {
1647 left_to_add.push(lk);
1648 }
1649 if r2o.try_map(rk).is_none() {
1650 right_to_add.push(rk + left_len)
1651 }
1652 }
1653 };
1654 }
1655 let left_to_add = left_to_add.into_iter().unique();
1656 let right_to_add = right_to_add.into_iter().unique();
1657 let mut new_output_indices = join.output_indices().clone();
1660 if !join.is_right_join() {
1661 new_output_indices.extend(left_to_add);
1662 }
1663 if !join.is_left_join() {
1664 new_output_indices.extend(right_to_add);
1665 }
1666
1667 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1668
1669 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1670 let l2o = join_with_pk
1673 .core
1674 .l2i_col_mapping()
1675 .composite(&join_with_pk.core.i2o_col_mapping());
1676 let r2o = join_with_pk
1677 .core
1678 .r2i_col_mapping()
1679 .composite(&join_with_pk.core.i2o_col_mapping());
1680 let left_right_stream_keys = join_with_pk
1681 .left()
1682 .expect_stream_key()
1683 .iter()
1684 .map(|i| l2o.map(*i))
1685 .chain(
1686 join_with_pk
1687 .right()
1688 .expect_stream_key()
1689 .iter()
1690 .map(|i| r2o.map(*i)),
1691 )
1692 .collect_vec();
1693 let plan: PlanRef = join_with_pk.into();
1694 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1695 } else {
1696 join_with_pk.into()
1697 };
1698
1699 Ok((plan, out_col_change))
1701 }
1702}
1703
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Builds the fixture shared by the join tests: six `Int32` fields named `v1` to `v6` and
    /// two empty `LogicalValues` inputs, the left over `v1..=v3` and the right over `v4..=v6`.
    async fn mock_join_inputs() -> (Vec<Field>, LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        (fields, left, right)
    }

    /// An `Int32` input reference to column `index`.
    fn input_ref(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// The expression `lhs = rhs`.
    fn equal(lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![lhs, rhs]).unwrap(),
        ))
    }

    /// Prunes `join` to `required_cols`, using `join` itself as the pruning context root.
    fn prune(join: &PlanRef, required_cols: &[usize]) -> PlanRef {
        join.prune_col(required_cols, &mut ColumnPruningContext::new(join.clone()))
    }

    /// Prunes an inner join on `$1 = $3` to the columns `[2, 3]`.
    ///
    /// The join must output `v3, v4`, keep `v2, v3` on the left and `v4` on the right, and
    /// have its condition rewritten to `$0 = $2`.
    #[tokio::test]
    async fn test_prune_join() {
        let (fields, left, right) = mock_join_inputs().await;
        let on = equal(input_ref(1), input_ref(3));
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on),
        )
        .into();

        let plan = prune(&join, &[2, 3]);

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Prunes semi and anti joins on `$1 = $4` and checks that the output columns come from
    /// the side selected by `is_right_join`.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let (fields, left, right) = mock_join_inputs().await;
        let on = equal(input_ref(1), input_ref(4));
        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Right joins output the right input, whose fields start at index 3.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();

            let plan = prune(&join, &[0]);
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            let plan = prune(&join, &[0, 1, 2]);
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 3);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
        }
    }

    /// Prunes an inner join on `$1 = $3` to exactly the join-key columns `[1, 3]`.
    ///
    /// Each side must keep a single column and the condition must become `$0 = $1`.
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let (fields, left, right) = mock_join_inputs().await;
        let on = equal(input_ref(1), input_ref(3));
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on),
        )
        .into();

        let plan = prune(&join, &[1, 3]);

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Converts an inner join on `$1 = $3 AND $2 = 42` to batch and checks that the result is
    /// a hash join whose predicate is split into the equality and the remaining condition.
    #[tokio::test]
    async fn test_join_to_batch() {
        let (_fields, left, right) = mock_join_inputs().await;

        let eq_cond = equal(input_ref(1), input_ref(3));
        let non_eq_cond = equal(
            input_ref(2),
            ExprImpl::Literal(Box::new(Literal::new(
                Datum::Some(42_i32.into()),
                DataType::Int32,
            ))),
        );
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on_cond),
        );

        let result = logical_join.to_batch().unwrap();

        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    // Ignored test with an empty body; kept so the test name stays registered.
    #[tokio::test]
    #[ignore]
    async fn test_join_to_stream() {}

    /// Prunes an inner join on `$1 = $3` to `[3, 2]` and checks that the requested column
    /// order is preserved in the output schema.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let (fields, left, right) = mock_join_inputs().await;
        let on = equal(input_ref(1), input_ref(3));
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on),
        )
        .into();

        let plan = prune(&join, &[3, 2]);

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Checks the functional dependency set of a join built by `LogicalJoin::new` against a
    /// fixed expectation for every join type listed below.
    ///
    /// The left input (`l0, l1`) registers a dependency from column 0 to column 1. The right
    /// input (`r0, r1, r2`) registers one from column 0 to columns 1 and 2. The join condition
    /// is `$0 = 0 AND $1 = $3`.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock().await;
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };
        let on: ExprImpl = FunctionCall::new(
            Type::And,
            vec![
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(0, DataType::Int32).into(),
                        ExprImpl::literal_int(0),
                    ],
                )
                .unwrap()
                .into(),
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(1, DataType::Int32).into(),
                        InputRef::new(3, DataType::Int32).into(),
                    ],
                )
                .unwrap()
                .into(),
            ],
        )
        .unwrap()
        .into();
        let expected_fd_set = [
            (
                JoinType::Inner,
                HashSet::from([
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                HashSet::from([FunctionalDependency::with_indices(5, &[2], &[3, 4])]),
            ),
            (
                JoinType::LeftOuter,
                HashSet::from([FunctionalDependency::with_indices(5, &[0], &[1])]),
            ),
            (
                JoinType::LeftSemi,
                HashSet::from([FunctionalDependency::with_indices(2, &[0], &[1])]),
            ),
            (
                JoinType::LeftAnti,
                HashSet::from([FunctionalDependency::with_indices(2, &[0], &[1])]),
            ),
            (
                JoinType::RightSemi,
                HashSet::from([FunctionalDependency::with_indices(3, &[0], &[1, 2])]),
            ),
            (
                JoinType::RightAnti,
                HashSet::from([FunctionalDependency::with_indices(3, &[0], &[1, 2])]),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}