1use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31 BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
32 PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch,
33 ToStream, generic,
34};
35use crate::error::{ErrorCode, Result, RwError};
36use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
37use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
38use crate::optimizer::plan_node::generic::DynamicFilter;
39use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
40use crate::optimizer::plan_node::utils::IndicesDisplay;
41use crate::optimizer::plan_node::{
42 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
43 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
44 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
45};
46use crate::optimizer::plan_visitor::LogicalCardinalityExt;
47use crate::optimizer::property::{Distribution, RequiredDist};
48use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
49
/// Logical plan node that joins a left and a right input.
///
/// All join-specific state (the two inputs, the `on` condition, the join type and the output
/// indices) lives in `core`; `base` is derived from that core in [`LogicalJoin::with_core`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Logical plan base built from `core`.
    pub base: PlanBase<Logical>,
    /// Generic join description: inputs, `on` condition, join type and output indices.
    core: generic::Join<PlanRef>,
}
61
62impl Distill for LogicalJoin {
63 fn distill<'a>(&self) -> XmlNode<'a> {
64 let verbose = self.base.ctx().is_explain_verbose();
65 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
66 vec.push(("type", Pretty::debug(&self.join_type())));
67
68 let concat_schema = self.core.concat_schema();
69 let cond = Pretty::debug(&ConditionDisplay {
70 condition: self.on(),
71 input_schema: &concat_schema,
72 });
73 vec.push(("on", cond));
74
75 if verbose {
76 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
77 vec.push(("output", data));
78 }
79
80 childless_record("LogicalJoin", vec)
81 }
82}
83
84impl LogicalJoin {
85 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
86 let core = generic::Join::with_full_output(left, right, join_type, on);
87 Self::with_core(core)
88 }
89
90 pub(crate) fn with_output_indices(
91 left: PlanRef,
92 right: PlanRef,
93 join_type: JoinType,
94 on: Condition,
95 output_indices: Vec<usize>,
96 ) -> Self {
97 let core = generic::Join::new(left, right, on, join_type, output_indices);
98 Self::with_core(core)
99 }
100
101 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
102 let base = PlanBase::new_logical_with_core(&core);
103 LogicalJoin { base, core }
104 }
105
106 pub fn create(
107 left: PlanRef,
108 right: PlanRef,
109 join_type: JoinType,
110 on_clause: ExprImpl,
111 ) -> PlanRef {
112 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
113 }
114
    /// Number of internal columns of the join, as reported by the core.
    ///
    /// NOTE(review): presumably the column count before `output_indices` is applied (this
    /// file compares `output_indices` against `0..internal_column_num()`) — confirm in
    /// `generic::Join`.
    pub fn internal_column_num(&self) -> usize {
        self.core.internal_column_num()
    }
118
    /// Delegates to the core's `i2l_col_mapping_ignore_join_type`.
    ///
    /// NOTE(review): by its name this maps internal columns to left-input columns without
    /// taking the join type into account — confirm in `generic::Join`.
    pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2l_col_mapping_ignore_join_type()
    }
122
    /// Delegates to the core's `i2r_col_mapping_ignore_join_type`.
    ///
    /// NOTE(review): by its name this maps internal columns to right-input columns without
    /// taking the join type into account — confirm in `generic::Join`.
    pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
        self.core.i2r_col_mapping_ignore_join_type()
    }
126
    /// Returns the join's `on` condition.
    pub fn on(&self) -> &Condition {
        &self.core.on
    }
131
    /// Returns the underlying generic join core.
    pub fn core(&self) -> &generic::Join<PlanRef> {
        &self.core
    }
135
136 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
138 let input_refs = self
139 .core
140 .on
141 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
142 let index_group = input_refs
143 .ones()
144 .chunk_by(|i| *i < self.core.left.schema().len());
145 let left_index = index_group
146 .into_iter()
147 .next()
148 .map_or(vec![], |group| group.1.collect_vec());
149 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
150 group
151 .1
152 .map(|i| i - self.core.left.schema().len())
153 .collect_vec()
154 });
155 (left_index, right_index)
156 }
157
    /// Returns the join type stored in the core.
    pub fn join_type(&self) -> JoinType {
        self.core.join_type
    }
162
    /// Delegates to the core's `eq_indexes`.
    ///
    /// NOTE(review): presumably the `(left index, right index)` column pairs compared for
    /// equality in the `on` condition — confirm in `generic::Join`.
    pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
        self.core.eq_indexes()
    }
167
    /// Returns the output indices stored in the core.
    pub fn output_indices(&self) -> &Vec<usize> {
        &self.core.output_indices
    }
172
173 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
175 Self::with_core(generic::Join {
176 output_indices,
177 ..self.core.clone()
178 })
179 }
180
181 pub fn clone_with_cond(&self, on: Condition) -> Self {
183 Self::with_core(generic::Join {
184 on,
185 ..self.core.clone()
186 })
187 }
188
189 pub fn is_left_join(&self) -> bool {
190 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
191 }
192
193 pub fn is_right_join(&self) -> bool {
194 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
195 }
196
    /// Delegates to the core's `is_full_out`.
    ///
    /// NOTE(review): presumably whether the join outputs all of its internal columns —
    /// confirm in `generic::Join`.
    pub fn is_full_out(&self) -> bool {
        self.core.is_full_out()
    }
200
201 pub fn is_asof_join(&self) -> bool {
202 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
203 }
204
205 pub fn output_indices_are_trivial(&self) -> bool {
206 self.output_indices() == &(0..self.internal_column_num()).collect_vec()
207 }
208
209 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
214 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
215 JoinType::LeftOuter => (false, true),
216 JoinType::RightOuter => (true, false),
217 JoinType::FullOuter => (true, true),
218 _ => return join_type,
219 };
220
221 for expr in &predicate.conjunctions {
222 if let ExprImpl::FunctionCall(func) = expr {
223 match func.func_type() {
224 ExprType::Equal
225 | ExprType::NotEqual
226 | ExprType::LessThan
227 | ExprType::LessThanOrEqual
228 | ExprType::GreaterThan
229 | ExprType::GreaterThanOrEqual => {
230 for input in func.inputs() {
231 if let ExprImpl::InputRef(input) = input {
232 let idx = input.index;
233 if idx < left_col_num {
234 gen_null_in_left = false;
235 } else {
236 gen_null_in_right = false;
237 }
238 }
239 }
240 }
241 _ => {}
242 };
243 }
244 }
245
246 match (gen_null_in_left, gen_null_in_right) {
247 (true, true) => JoinType::FullOuter,
248 (true, false) => JoinType::RightOuter,
249 (false, true) => JoinType::LeftOuter,
250 (false, false) => JoinType::Inner,
251 }
252 }
253
254 fn to_batch_lookup_join_with_index_selection(
258 &self,
259 predicate: EqJoinPredicate,
260 batch_join: generic::Join<BatchPlanRef>,
261 ) -> Result<Option<BatchLookupJoin>> {
262 match batch_join.join_type {
263 JoinType::Inner
264 | JoinType::LeftOuter
265 | JoinType::LeftSemi
266 | JoinType::LeftAnti
267 | JoinType::AsofInner
268 | JoinType::AsofLeftOuter => {}
269 _ => return Ok(None),
270 };
271
272 let right = self.right();
274 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
276 logical_scan
277 } else {
278 return Ok(None);
279 };
280
281 let mut result_plan = None;
282 if let Some(lookup_join) =
284 self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
285 {
286 result_plan = Some(lookup_join);
287 }
288
289 let indexes = logical_scan.table_indexes();
290 for index in indexes {
291 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
292 let index_scan: PlanRef = index_scan.into();
293 let that = self.clone_with_left_right(self.left(), index_scan.clone());
294 let mut new_batch_join = batch_join.clone();
295 new_batch_join.right = index_scan.to_batch().expect("index scan failed to batch");
296
297 if let Some(lookup_join) =
299 that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
300 {
301 match &result_plan {
302 None => result_plan = Some(lookup_join),
303 Some(prev_lookup_join) => {
304 if prev_lookup_join.lookup_prefix_len()
306 < lookup_join.lookup_prefix_len()
307 {
308 result_plan = Some(lookup_join)
309 }
310 }
311 }
312 }
313 }
314 }
315
316 Ok(result_plan)
317 }
318
    /// Tries to build a `BatchLookupJoin` that looks up the table scanned by the right side.
    ///
    /// Returns `Ok(None)` when any of the following holds:
    /// * the join type is not one of inner / left outer / left semi / left anti / asof inner /
    ///   asof left outer;
    /// * the right side is not a logical scan;
    /// * the table has no distribution key;
    /// * the equal conditions do not cover the order-key prefix that contains every
    ///   distribution-key column;
    /// * the rewritten predicate has no equal condition left.
    ///
    /// `predicate` is the equi-join predicate over the left input followed by the right scan;
    /// `logical_join` is the join core whose inputs are already batch plans.
    fn to_batch_lookup_join(
        &self,
        predicate: EqJoinPredicate,
        logical_join: generic::Join<BatchPlanRef>,
    ) -> Result<Option<BatchLookupJoin>> {
        match logical_join.join_type {
            JoinType::Inner
            | JoinType::LeftOuter
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::AsofInner
            | JoinType::AsofLeftOuter => {}
            _ => return Ok(None),
        };

        let right = self.right();
        // The lookup side must be a plain scan.
        let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
            logical_scan
        } else {
            return Ok(None);
        };
        let table = logical_scan.table();
        let output_column_ids = logical_scan.output_column_ids();

        let order_col_ids = table.order_column_ids();
        let dist_key = table.distribution_key.clone();
        // Position of every distribution-key column inside the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = table
                .order_column_indices()
                .position(|x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // Length of the shortest order-key prefix that contains all distribution-key columns;
        // 0 when the table has no distribution key.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        if shortest_prefix_len == 0 {
            return Ok(None);
        }

        // For each order-key column, in order, find the equal condition whose right column is
        // that column. Stop at the first order-key column without one, so `reorder_idx`
        // describes a contiguous order-key prefix.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        if reorder_idx.len() < shortest_prefix_len {
            return Ok(None);
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Pull the scan's own predicate up into the join condition.
        let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
        // `o2r`: for each output position of the old scan, the matching position in `new_scan`
        // (identity when the pull-up returned no projection).
        let o2r = if let Some(project_expr) = project_expr {
            project_expr
                .into_iter()
                .map(|x| x.as_input_ref().unwrap().index)
                .collect_vec()
        } else {
            (0..logical_scan.output_col_idx().len()).collect_vec()
        };
        let left_schema_len = logical_join.left.schema().len();

        // Remaps right-side references of the join predicate through `o2r`.
        let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
            offset: left_schema_len,
            mapping: o2r.clone(),
        };

        let new_eq_cond = predicate
            .eq_cond()
            .rewrite_expr(&mut join_predicate_rewriter);

        // Shifts the pulled-up scan predicate behind the left columns.
        let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
            offset: left_schema_len,
        };

        let new_other_cond = predicate
            .other_cond()
            .clone()
            .rewrite_expr(&mut join_predicate_rewriter)
            .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

        let new_join_on = new_eq_cond.and(new_other_cond);
        let new_predicate = EqJoinPredicate::create(
            left_schema_len,
            new_scan.schema().len(),
            new_join_on.clone(),
        );

        // A lookup join needs at least one equal condition.
        if !new_predicate.has_eq() {
            return Ok(None);
        }

        // Output indices pointing at the right side must follow the scan's new column layout.
        let new_join_output_indices = logical_join
            .output_indices
            .iter()
            .map(|&x| {
                if x < left_schema_len {
                    x
                } else {
                    o2r[x - left_schema_len] + left_schema_len
                }
            })
            .collect_vec();

        let new_scan_output_column_ids = new_scan.output_column_ids();
        let as_of = new_scan.as_of.clone();
        let new_logical_scan: LogicalScan = new_scan.into();

        let new_logical_join = generic::Join::new(
            logical_join.left,
            new_logical_scan.to_batch()?,
            new_join_on,
            logical_join.join_type,
            new_join_output_indices,
        );

        // Only asof joins carry an inequality description; it is derived from the non-equal
        // part of the (reordered, not yet rewritten) predicate.
        let asof_desc = self
            .is_asof_join()
            .then(|| {
                Self::get_inequality_desc_from_predicate(
                    predicate.other_cond().clone(),
                    left_schema_len,
                )
            })
            .transpose()?;

        // NOTE(review): the meaning of the literal `false` argument is defined by
        // `BatchLookupJoin::new` and is not visible here.
        Ok(Some(BatchLookupJoin::new(
            new_logical_join,
            new_predicate,
            table.clone(),
            new_scan_output_column_ids,
            lookup_prefix_len,
            false,
            as_of,
            asof_desc,
        )))
    }
483
    /// Consumes the node and returns the parts of its core as produced by
    /// `generic::Join::decompose`: two plans, the condition, the join type and a `Vec<usize>`
    /// (presumably left input, right input and the output indices — confirm in the core).
    pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
        self.core.decompose()
    }
487}
488
impl PlanTreeNodeBinary<Logical> for LogicalJoin {
    /// Returns a clone of the left input.
    fn left(&self) -> PlanRef {
        self.core.left.clone()
    }

    /// Returns a clone of the right input.
    fn right(&self) -> PlanRef {
        self.core.right.clone()
    }

    /// Returns a copy of this join with both inputs replaced; condition, join type and output
    /// indices are cloned unchanged.
    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
        Self::with_core(generic::Join {
            left,
            right,
            ..self.core.clone()
        })
    }

    /// Rebuilds the join on top of rewritten inputs.
    ///
    /// `left_col_change` / `right_col_change` describe how the columns of the old inputs map
    /// to the columns of `left` / `right`. Returns the new join together with the mapping from
    /// the old join's output columns to the new join's output columns.
    fn rewrite_with_left_right(
        &self,
        left: PlanRef,
        left_col_change: ColIndexMapping,
        right: PlanRef,
        right_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        let (new_on, new_output_indices) = {
            // Concatenate both input mappings into one mapping over the join's internal
            // columns; right-side targets are shifted behind the new left columns.
            let (mut map, _) = left_col_change.clone().into_parts();
            let (mut right_map, _) = right_col_change.clone().into_parts();
            for i in right_map.iter_mut().flatten() {
                *i += left.schema().len();
            }
            map.append(&mut right_map);
            let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());

            let new_output_indices = self
                .output_indices()
                .iter()
                .map(|&i| mapping.map(i))
                .collect::<Vec<_>>();
            let new_on = self.on().clone().rewrite_expr(&mut mapping);
            (new_on, new_output_indices)
        };

        let join = Self::with_output_indices(
            left,
            right,
            self.join_type(),
            new_on,
            new_output_indices.clone(),
        );

        // Internal -> output mapping of the new join.
        let new_i2o = ColIndexMapping::with_remaining_columns(
            &new_output_indices,
            join.internal_column_num(),
        );

        let old_o2i = self.core.o2i_col_mapping();

        // Old output -> old input side -> new input side, for each side.
        let old_o2l = old_o2i
            .composite(&self.core.i2l_col_mapping())
            .composite(&left_col_change);
        let old_o2r = old_o2i
            .composite(&self.core.i2r_col_mapping())
            .composite(&right_col_change);
        // New input side -> new internal -> new output, for each side.
        let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
        let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);

        // Old output -> new output is the union of the paths through the left and the right.
        let out_col_change = old_o2l
            .composite(&new_l2o)
            .union(&old_o2r.composite(&new_r2o));
        (join, out_col_change)
    }
}
561
// Macro defined elsewhere; presumably derives the generic plan-tree-node implementation for
// `LogicalJoin` from its `PlanTreeNodeBinary<Logical>` impl — the expansion is not visible here.
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
563
impl ColPrunable for LogicalJoin {
    /// Prunes the join so that it outputs only `required_cols` (indices into the current
    /// output), pruning both inputs down to the columns needed by the output and by the `on`
    /// condition.
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        // Translate required output columns into the join's internal column indices.
        let required_cols = required_cols
            .iter()
            .map(|i| self.output_indices()[*i])
            .collect_vec();
        let left_len = self.left().schema().fields.len();

        let total_len = self.left().schema().len() + self.right().schema().len();
        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);

        // Express the required columns in the concatenated (left ++ right) index space. For
        // right semi/anti joins the internal indices are shifted by the left length.
        required_cols.iter().for_each(|&i| {
            if self.is_right_join() {
                resized_required_cols.insert(left_len + i);
            } else {
                resized_required_cols.insert(i);
            }
        });

        // Add every column referenced by the `on` condition.
        let mut visitor = CollectInputRef::new(resized_required_cols);
        self.on().visit_expr(&mut visitor);
        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();

        // Split the required columns per input; right indices become input-local.
        let mut left_required_cols = Vec::new();
        let mut right_required_cols = Vec::new();
        left_right_required_cols.iter().for_each(|&i| {
            if i < left_len {
                left_required_cols.push(i);
            } else {
                right_required_cols.push(i - left_len);
            }
        });

        // Rewrite the condition against the pruned inputs.
        let mut on = self.on().clone();
        let mut mapping =
            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
        on = on.rewrite_expr(&mut mapping);

        let new_output_indices = {
            // Semi/anti joins map against one side's remaining columns only; all other join
            // types map against the remaining columns of both sides.
            let required_inputs_in_output = if self.is_left_join() {
                &left_required_cols
            } else if self.is_right_join() {
                &right_required_cols
            } else {
                &left_right_required_cols
            };

            let mapping =
                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
        };

        LogicalJoin::with_output_indices(
            self.left().prune_col(&left_required_cols, ctx),
            self.right().prune_col(&right_required_cols, ctx),
            self.join_type(),
            on,
            new_output_indices,
        )
        .into()
    }
}
629
630impl ExprRewritable<Logical> for LogicalJoin {
631 fn has_rewritable_expr(&self) -> bool {
632 true
633 }
634
635 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
636 let mut core = self.core.clone();
637 core.rewrite_exprs(r);
638 Self {
639 base: self.base.clone_with_new_plan_id(),
640 core,
641 }
642 .into()
643 }
644}
645
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions held by the join core with `v`.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
651
652fn derive_predicate_from_eq_condition(
670 expr: &ExprImpl,
671 eq_condition: &EqJoinPredicate,
672 col_num: usize,
673 expr_is_left: bool,
674) -> Option<ExprImpl> {
675 if expr.is_impure() {
676 return None;
677 }
678 let eq_indices = eq_condition
679 .eq_indexes_typed()
680 .iter()
681 .filter_map(|(l, r)| {
682 if l.return_type() != r.return_type() {
683 None
684 } else if expr_is_left {
685 Some(l.index())
686 } else {
687 Some(r.index())
688 }
689 })
690 .collect_vec();
691 if expr
692 .collect_input_refs(col_num)
693 .ones()
694 .any(|index| !eq_indices.contains(&index))
695 {
696 return None;
698 }
699 let other_side_mapping = if expr_is_left {
702 eq_condition.eq_indexes_typed().into_iter().collect()
703 } else {
704 eq_condition
705 .eq_indexes_typed()
706 .into_iter()
707 .map(|(x, y)| (y, x))
708 .collect()
709 };
710 struct InputRefsRewriter {
711 mapping: HashMap<InputRef, InputRef>,
712 }
713 impl ExprRewriter for InputRefsRewriter {
714 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
715 self.mapping[&input_ref].clone().into()
716 }
717 }
718 Some(
719 InputRefsRewriter {
720 mapping: other_side_mapping,
721 }
722 .rewrite_expr(expr.clone()),
723 )
724}
725
726struct LookupJoinPredicateRewriter {
728 offset: usize,
729 mapping: Vec<usize>,
730}
731impl ExprRewriter for LookupJoinPredicateRewriter {
732 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
733 if input_ref.index() < self.offset {
734 input_ref.into()
735 } else {
736 InputRef::new(
737 self.mapping[input_ref.index() - self.offset] + self.offset,
738 input_ref.return_type(),
739 )
740 .into()
741 }
742 }
743}
744
745struct LookupJoinScanPredicateRewriter {
747 offset: usize,
748}
749impl ExprRewriter for LookupJoinScanPredicateRewriter {
750 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
751 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
752 }
753}
754
impl PredicatePushdown for LogicalJoin {
    /// Pushes `predicate` (expressed over the join's output) into the join.
    ///
    /// Steps, in order:
    /// 1. rewrite `predicate` to the join's internal columns;
    /// 2. try to simplify an outer join type using `predicate`;
    /// 3. move parts of `predicate` into the left input, the right input or the `on` condition;
    /// 4. move parts of the combined `on` condition into the inputs;
    /// 5. derive extra input predicates across the equal conditions;
    /// 6. push the collected predicates into both inputs and rebuild the join;
    /// 7. keep whatever is left of `predicate` in a `LogicalFilter` above the new join.
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        // Output columns -> internal columns.
        let mut predicate = {
            let mut mapping = self.core.o2i_col_mapping();
            predicate.rewrite_expr(&mut mapping)
        };

        let left_col_num = self.left().schema().len();
        let right_col_num = self.right().schema().len();
        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

        // The flag handed to both push-down helpers is disabled when the right side is a scan
        // with `AS OF PROCTIME()`.
        let push_down_temporal_predicate = !self.should_be_temporal_join();

        // `predicate` is mutated: only the parts that could not be moved remain in it.
        let (left_from_filter, right_from_filter, on) = push_down_into_join(
            &mut predicate,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        // `new_on` is mutated the same way: only what stays in the join condition remains.
        let mut new_on = self.on().clone().and(on);
        let (left_from_on, right_from_on) = push_down_join_condition(
            &mut new_on,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        let left_predicate = left_from_filter.and(left_from_on);
        let right_predicate = right_from_filter.and(right_from_on);

        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());

        // Left predicates are mirrored onto the right input only for these join types.
        let right_from_left = if matches!(
            join_type,
            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
        ) {
            Condition {
                conjunctions: left_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Right predicates are mirrored onto the left input only for these join types.
        let left_from_right = if matches!(
            join_type,
            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
        ) {
            Condition {
                conjunctions: right_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(
                            expr,
                            &eq_condition,
                            right_col_num,
                            false,
                        )
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        let left_predicate = left_predicate.and(left_from_right);
        let right_predicate = right_predicate.and(right_from_left);

        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
        let new_join = LogicalJoin::with_output_indices(
            new_left,
            new_right,
            join_type,
            new_on,
            self.output_indices().clone(),
        );

        // Internal columns -> output columns for the part of the predicate that stays above.
        let mut mapping = self.core.i2o_col_mapping();
        predicate = predicate.rewrite_expr(&mut mapping);
        LogicalFilter::create(new_join.into(), predicate)
    }
}
878
879impl LogicalJoin {
    /// Converts both inputs to stream plans whose distributions are compatible for a hash join
    /// on the equal keys of `predicate`, returning `(left, right)`.
    ///
    /// The right input is converted first, required to be sharded by the right equal keys.
    /// What happens next depends on the distribution it ends up with:
    /// * `HashShard` — the left input is required to match it, translated through the
    ///   right-to-left key mapping;
    /// * `UpstreamHashShard` — the left input is converted sharded by the left equal keys; if
    ///   it is `HashShard`, the right input is enforced to match it, and if it is also
    ///   `UpstreamHashShard`, both sides are enforced to hash-shard on their equal keys.
    ///
    /// Any other distribution is treated as unreachable.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        let mut right = self.right().to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        let logical_left = self.left();

        // Column mappings between the two sides along the equal keys.
        let r2l =
            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());

        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // Make the left side follow the right side's physical distribution.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = self
                    .core
                    .left
                    .to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                left = self.core.left.to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Make the right side follow the left side's physical distribution.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Neither side is hash-sharded: enforce hash sharding on both.
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
938
939 fn to_stream_hash_join(
940 &self,
941 predicate: EqJoinPredicate,
942 ctx: &mut ToStreamContext,
943 ) -> Result<StreamPlanRef> {
944 use super::stream::prelude::*;
945
946 assert!(predicate.has_eq());
947 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
948
949 let core = self.core.clone_with_inputs(left, right);
950
951 let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
960
961 let force_filter_inside_join = self
962 .base
963 .ctx()
964 .session_ctx()
965 .config()
966 .streaming_force_filter_inside_join();
967
968 let pull_filter = self.join_type() == JoinType::Inner
969 && stream_hash_join.eq_join_predicate().has_non_eq()
970 && stream_hash_join.inequality_pairs().is_empty()
971 && (!force_filter_inside_join);
972 if pull_filter {
973 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
974
975 let mut core = core.clone();
976 core.output_indices = default_indices.clone();
977 let eq_cond = EqJoinPredicate::new(
979 Condition::true_cond(),
980 predicate.eq_keys().to_vec(),
981 self.left().schema().len(),
982 self.right().schema().len(),
983 );
984 core.on = eq_cond.eq_cond();
985 let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
986 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
987 let plan = StreamFilter::new(logical_filter).into();
988 if self.output_indices() != &default_indices {
989 let logical_project = generic::Project::with_mapping(
990 plan,
991 ColIndexMapping::with_remaining_columns(
992 self.output_indices(),
993 self.internal_column_num(),
994 ),
995 );
996 Ok(StreamProject::new(logical_project).into())
997 } else {
998 Ok(plan)
999 }
1000 } else {
1001 Ok(stream_hash_join.into())
1002 }
1003 }
1004
1005 fn should_be_temporal_join(&self) -> bool {
1006 let right = self.right();
1007 if let Some(logical_scan) = right.as_logical_scan() {
1008 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1009 } else {
1010 false
1011 }
1012 }
1013
1014 fn to_stream_temporal_join_with_index_selection(
1015 &self,
1016 predicate: EqJoinPredicate,
1017 ctx: &mut ToStreamContext,
1018 ) -> Result<StreamPlanRef> {
1019 let right = self.right();
1021 let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1023
1024 let mut result_plan: Result<StreamTemporalJoin> =
1026 self.to_stream_temporal_join(predicate.clone(), ctx);
1027 if let Ok(temporal_join) = &result_plan
1029 && temporal_join.eq_join_predicate().eq_indexes().len()
1030 == logical_scan.primary_key().len()
1031 {
1032 return result_plan.map(|x| x.into());
1033 }
1034 let indexes = logical_scan.table_indexes();
1035 for index in indexes {
1036 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1038 let index_scan: PlanRef = index_scan.into();
1039 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1040 if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1041 match &result_plan {
1042 Err(_) => result_plan = Ok(temporal_join),
1043 Ok(prev_temporal_join) => {
1044 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1046 < temporal_join.eq_join_predicate().eq_indexes().len()
1047 {
1048 result_plan = Ok(temporal_join)
1049 }
1050 }
1051 }
1052 }
1053 }
1054 }
1055
1056 result_plan.map(|x| x.into())
1057 }
1058
1059 fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1060 let Some(logical_scan) = right.as_logical_scan() else {
1061 return Err(RwError::from(ErrorCode::NotSupported(
1062 "Temporal join requires a table scan as its lookup table".into(),
1063 "Please provide a table scan".into(),
1064 )));
1065 };
1066
1067 if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1068 return Err(RwError::from(ErrorCode::NotSupported(
1069 "Temporal join requires a table defined as temporal table".into(),
1070 "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1071 )));
1072 }
1073 Ok(logical_scan)
1074 }
1075
1076 fn temporal_join_scan_predicate_pull_up(
1077 logical_scan: &LogicalScan,
1078 predicate: EqJoinPredicate,
1079 output_indices: &[usize],
1080 left_schema_len: usize,
1081 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1082 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1084 let o2r = if let Some(project_expr) = project_expr {
1086 project_expr
1087 .into_iter()
1088 .map(|x| x.as_input_ref().unwrap().index)
1089 .collect_vec()
1090 } else {
1091 (0..logical_scan.output_col_idx().len()).collect_vec()
1092 };
1093 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1094 offset: left_schema_len,
1095 mapping: o2r.clone(),
1096 };
1097
1098 let new_eq_cond = predicate
1099 .eq_cond()
1100 .rewrite_expr(&mut join_predicate_rewriter);
1101
1102 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1103 offset: left_schema_len,
1104 };
1105
1106 let new_other_cond = predicate
1107 .other_cond()
1108 .clone()
1109 .rewrite_expr(&mut join_predicate_rewriter)
1110 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1111
1112 let new_join_on = new_eq_cond.and(new_other_cond);
1113
1114 let new_predicate = EqJoinPredicate::create(
1115 left_schema_len,
1116 new_scan.schema().len(),
1117 new_join_on.clone(),
1118 );
1119
1120 let new_join_output_indices = output_indices
1123 .iter()
1124 .map(|&x| {
1125 if x < left_schema_len {
1126 x
1127 } else {
1128 o2r[x - left_schema_len] + left_schema_len
1129 }
1130 })
1131 .collect_vec();
1132 let new_stream_table_scan =
1134 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1135 Ok((
1136 new_stream_table_scan,
1137 new_predicate,
1138 new_join_on,
1139 new_join_output_indices,
1140 ))
1141 }
1142
    /// Builds a `StreamTemporalJoin` that looks up the table scanned by the right side.
    ///
    /// # Errors
    /// Returns `ErrorCode::NotSupported` when the right side is not a suitable temporal table
    /// (see `check_temporal_rhs`), when the equal conditions do not cover the order-key prefix
    /// containing every distribution-key column, or when the rewritten predicate has no equal
    /// condition. Errors from converting the left input or constructing the join are
    /// propagated.
    ///
    /// # Panics
    /// Panics if `predicate` has no equal condition.
    fn to_stream_temporal_join(
        &self,
        predicate: EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<StreamTemporalJoin> {
        use super::stream::prelude::*;

        assert!(predicate.has_eq());

        let right = self.right();

        let logical_scan = Self::check_temporal_rhs(&right)?;

        let table = logical_scan.table();
        let output_column_ids = logical_scan.output_column_ids();

        let order_col_ids = table.order_column_ids();
        let dist_key = table.distribution_key.clone();

        // Position of every distribution-key column inside the order key.
        let mut dist_key_in_order_key_pos = vec![];
        for d in dist_key {
            let pos = table
                .order_column_indices()
                .position(|x| x == d)
                .expect("dist_key must in order_key");
            dist_key_in_order_key_pos.push(pos);
        }
        // Length of the shortest order-key prefix that contains all distribution-key columns;
        // 0 when the table has no distribution key.
        let shortest_prefix_len = dist_key_in_order_key_pos
            .iter()
            .max()
            .map_or(0, |pos| pos + 1);

        // For each order-key column, in order, find the equal condition whose right column is
        // that column; stop at the first order-key column without one.
        let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
        for order_col_id in order_col_ids {
            let mut found = false;
            for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
                if order_col_id == output_column_ids[eq_idx] {
                    reorder_idx.push(i);
                    found = true;
                    break;
                }
            }
            if !found {
                break;
            }
        }
        if reorder_idx.len() < shortest_prefix_len {
            return Err(RwError::from(ErrorCode::NotSupported(
                "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
                "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
            )));
        }
        let lookup_prefix_len = reorder_idx.len();
        let predicate = predicate.reorder(&reorder_idx);

        // Distribute the left input by the left columns equated with the table's distribution
        // key; a table without a distribution key requires a single distribution.
        let required_dist = if dist_key_in_order_key_pos.is_empty() {
            RequiredDist::single()
        } else {
            let left_eq_indexes = predicate.left_eq_indexes();
            let left_dist_key = dist_key_in_order_key_pos
                .iter()
                .map(|pos| left_eq_indexes[*pos])
                .collect_vec();

            RequiredDist::hash_shard(&left_dist_key)
        };

        let left = self.left().to_stream(ctx)?;
        let left = required_dist.stream_enforce(left);

        let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
            Self::temporal_join_scan_predicate_pull_up(
                logical_scan,
                predicate,
                self.output_indices(),
                self.left().schema().len(),
            )?;

        let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
        if !new_predicate.has_eq() {
            return Err(RwError::from(ErrorCode::NotSupported(
                "Temporal join requires a non trivial join condition".into(),
                "Please remove the false condition of the join".into(),
            )));
        }

        let new_logical_join = generic::Join::new(
            left,
            right,
            new_join_on,
            self.join_type(),
            new_join_output_indices,
        );

        // Keep only the leading equal keys that form the lookup prefix.
        let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);

        // NOTE(review): the literal `false` is `true` in the nested-loop variant in this file,
        // so it presumably selects hash vs. nested-loop mode — confirm in `StreamTemporalJoin`.
        StreamTemporalJoin::new(new_logical_join, new_predicate, false)
    }
1248
1249 fn to_stream_nested_loop_temporal_join(
1250 &self,
1251 predicate: EqJoinPredicate,
1252 ctx: &mut ToStreamContext,
1253 ) -> Result<StreamPlanRef> {
1254 use super::stream::prelude::*;
1255 assert!(!predicate.has_eq());
1256
1257 let left = self.left().to_stream_with_dist_required(
1258 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1259 ctx,
1260 )?;
1261 assert!(left.as_stream_exchange().is_some());
1262
1263 if self.join_type() != JoinType::Inner {
1264 return Err(RwError::from(ErrorCode::NotSupported(
1265 "Temporal join requires an inner join".into(),
1266 "Please use an inner join".into(),
1267 )));
1268 }
1269
1270 if !left.append_only() {
1271 return Err(RwError::from(ErrorCode::NotSupported(
1272 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1273 "Please ensure the left hash side is append only".into(),
1274 )));
1275 }
1276
1277 let right = self.right();
1278 let logical_scan = Self::check_temporal_rhs(&right)?;
1279
1280 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1281 Self::temporal_join_scan_predicate_pull_up(
1282 logical_scan,
1283 predicate,
1284 self.output_indices(),
1285 self.left().schema().len(),
1286 )?;
1287
1288 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1289
1290 let new_logical_join = generic::Join::new(
1292 left,
1293 right,
1294 new_join_on,
1295 self.join_type(),
1296 new_join_output_indices,
1297 );
1298
1299 Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1300 }
1301
1302 fn to_stream_dynamic_filter(
1303 &self,
1304 predicate: Condition,
1305 ctx: &mut ToStreamContext,
1306 ) -> Result<Option<StreamPlanRef>> {
1307 use super::stream::prelude::*;
1308
1309 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1315 return Ok(None);
1316 }
1317
1318 if !self.right().max_one_row() {
1320 return Ok(None);
1321 }
1322 if self.right().schema().len() != 1 {
1323 return Ok(None);
1324 }
1325
1326 if predicate.conjunctions.len() > 1 {
1328 return Ok(None);
1329 }
1330 let expr: ExprImpl = predicate.into();
1331 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1332 Some(v) => v,
1333 None => return Ok(None),
1334 };
1335
1336 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1337 && right_ref.index == self.left().schema().len() ;
1338 if !condition_cross_inputs {
1339 return Ok(None);
1341 }
1342
1343 if self.left().schema().fields()[left_ref.index].data_type
1345 != self.right().schema().fields()[0].data_type
1346 {
1347 return Ok(None);
1348 }
1349
1350 let all_output_from_left = self
1352 .output_indices()
1353 .iter()
1354 .all(|i| *i < self.left().schema().len());
1355 if !all_output_from_left {
1356 return Ok(None);
1357 }
1358
1359 let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1360 let right = self.right().to_stream_with_dist_required(
1361 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1362 ctx,
1363 )?;
1364
1365 assert!(right.as_stream_exchange().is_some());
1366 assert_eq!(
1367 *right.inputs().iter().exactly_one().unwrap().distribution(),
1368 Distribution::Single
1369 );
1370
1371 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1372 let plan = StreamDynamicFilter::new(core)?.into();
1373 if self
1375 .output_indices()
1376 .iter()
1377 .copied()
1378 .ne(0..self.left().schema().len())
1379 {
1380 let logical_project = generic::Project::with_mapping(
1383 plan,
1384 ColIndexMapping::with_remaining_columns(
1385 self.output_indices(),
1386 self.left().schema().len(),
1387 ),
1388 );
1389 Ok(Some(StreamProject::new(logical_project).into()))
1390 } else {
1391 Ok(Some(plan))
1392 }
1393 }
1394
1395 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1396 let predicate = EqJoinPredicate::create(
1397 self.left().schema().len(),
1398 self.right().schema().len(),
1399 self.on().clone(),
1400 );
1401 assert!(predicate.has_eq());
1402
1403 let join = self
1404 .core
1405 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1406
1407 Ok(self
1408 .to_batch_lookup_join(predicate, join)?
1409 .expect("Fail to convert to lookup join")
1410 .into())
1411 }
1412
1413 fn to_stream_asof_join(
1414 &self,
1415 predicate: EqJoinPredicate,
1416 ctx: &mut ToStreamContext,
1417 ) -> Result<StreamPlanRef> {
1418 use super::stream::prelude::*;
1419
1420 if predicate.eq_keys().is_empty() {
1421 return Err(ErrorCode::InvalidInputSyntax(
1422 "AsOf join requires at least 1 equal condition".to_owned(),
1423 )
1424 .into());
1425 }
1426
1427 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1428 let left_len = left.schema().len();
1429 let core = self.core.clone_with_inputs(left, right);
1430
1431 let inequality_desc =
1432 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1433
1434 Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1435 }
1436
1437 fn to_batch_hash_join(
1439 &self,
1440 logical_join: generic::Join<BatchPlanRef>,
1441 predicate: EqJoinPredicate,
1442 ) -> Result<BatchPlanRef> {
1443 use super::batch::prelude::*;
1444
1445 let left_schema_len = logical_join.left.schema().len();
1446 let asof_desc = self
1447 .is_asof_join()
1448 .then(|| {
1449 Self::get_inequality_desc_from_predicate(
1450 predicate.other_cond().clone(),
1451 left_schema_len,
1452 )
1453 })
1454 .transpose()?;
1455
1456 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1457 Ok(batch_join.into())
1458 }
1459
1460 pub fn get_inequality_desc_from_predicate(
1461 predicate: Condition,
1462 left_input_len: usize,
1463 ) -> Result<AsOfJoinDesc> {
1464 let expr: ExprImpl = predicate.into();
1465 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1466 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1467 {
1468 Ok(AsOfJoinDesc {
1469 left_idx: left_input_ref.index() as u32,
1470 right_idx: (right_input_ref.index() - left_input_len) as u32,
1471 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1472 })
1473 } else {
1474 bail!("inequal condition from the same side should be push down in optimizer");
1475 }
1476 } else {
1477 Err(ErrorCode::InvalidInputSyntax(
1478 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1479 )
1480 .into())
1481 }
1482 }
1483
1484 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1485 match expr_type {
1486 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1487 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1488 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1489 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1490 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1491 "Invalid comparison type: {}",
1492 expr_type.as_str_name()
1493 ))
1494 .into()),
1495 }
1496 }
1497}
1498
1499impl ToBatch for LogicalJoin {
1500 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1501 let predicate = EqJoinPredicate::create(
1502 self.left().schema().len(),
1503 self.right().schema().len(),
1504 self.on().clone(),
1505 );
1506
1507 let batch_join = self
1508 .core
1509 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1510
1511 let ctx = self.base.ctx();
1512 let config = ctx.session_ctx().config();
1513
1514 if predicate.has_eq() {
1515 if !predicate.eq_keys_are_type_aligned() {
1516 return Err(ErrorCode::InternalError(format!(
1517 "Join eq keys are not aligned for predicate: {predicate:?}"
1518 ))
1519 .into());
1520 }
1521 if config.batch_enable_lookup_join()
1522 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1523 predicate.clone(),
1524 batch_join.clone(),
1525 )?
1526 {
1527 return Ok(lookup_join.into());
1528 }
1529 self.to_batch_hash_join(batch_join, predicate)
1530 } else if self.is_asof_join() {
1531 Err(ErrorCode::InvalidInputSyntax(
1532 "AsOf join requires at least 1 equal condition".to_owned(),
1533 )
1534 .into())
1535 } else {
1536 Ok(BatchNestedLoopJoin::new(batch_join).into())
1538 }
1539 }
1540}
1541
1542impl ToStream for LogicalJoin {
1543 fn to_stream(
1544 &self,
1545 ctx: &mut ToStreamContext,
1546 ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1547 if self
1548 .on()
1549 .conjunctions
1550 .iter()
1551 .any(|cond| cond.count_nows() > 0)
1552 {
1553 return Err(ErrorCode::NotSupported(
1554 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1555 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1556 }
1557
1558 let predicate = EqJoinPredicate::create(
1559 self.left().schema().len(),
1560 self.right().schema().len(),
1561 self.on().clone(),
1562 );
1563
1564 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1565 self.to_stream_asof_join(predicate, ctx)
1566 } else if predicate.has_eq() {
1567 if !predicate.eq_keys_are_type_aligned() {
1568 return Err(ErrorCode::InternalError(format!(
1569 "Join eq keys are not aligned for predicate: {predicate:?}"
1570 ))
1571 .into());
1572 }
1573
1574 if self.should_be_temporal_join() {
1575 self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1576 } else {
1577 self.to_stream_hash_join(predicate, ctx)
1578 }
1579 } else if self.should_be_temporal_join() {
1580 self.to_stream_nested_loop_temporal_join(predicate, ctx)
1581 } else if let Some(dynamic_filter) =
1582 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1583 {
1584 Ok(dynamic_filter)
1585 } else {
1586 Err(RwError::from(ErrorCode::NotSupported(
1587 "streaming nested-loop join".to_owned(),
1588 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1589 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1590 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1591 )))
1592 }
1593 }
1594
1595 fn logical_rewrite_for_stream(
1596 &self,
1597 ctx: &mut RewriteStreamContext,
1598 ) -> Result<(PlanRef, ColIndexMapping)> {
1599 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1600 let left_len = left.schema().len();
1601 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1602 let (join, out_col_change) = self.rewrite_with_left_right(
1603 left.clone(),
1604 left_col_change,
1605 right.clone(),
1606 right_col_change,
1607 );
1608
1609 let mapping = ColIndexMapping::with_remaining_columns(
1610 join.output_indices(),
1611 join.internal_column_num(),
1612 );
1613
1614 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1615 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1616
1617 let mut left_to_add = left
1619 .expect_stream_key()
1620 .iter()
1621 .cloned()
1622 .filter(|i| l2o.try_map(*i).is_none())
1623 .collect_vec();
1624
1625 let mut right_to_add = right
1626 .expect_stream_key()
1627 .iter()
1628 .filter(|&&i| r2o.try_map(i).is_none())
1629 .map(|&i| i + left_len)
1630 .collect_vec();
1631
1632 let right_len = right.schema().len();
1635 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1636
1637 let either_or_both = self.core.add_which_join_key_to_pk();
1638
1639 for (lk, rk) in eq_predicate.eq_indexes() {
1640 match either_or_both {
1641 EitherOrBoth::Left(_) => {
1642 if l2o.try_map(lk).is_none() {
1643 left_to_add.push(lk);
1644 }
1645 }
1646 EitherOrBoth::Right(_) => {
1647 if r2o.try_map(rk).is_none() {
1648 right_to_add.push(rk + left_len)
1649 }
1650 }
1651 EitherOrBoth::Both(_, _) => {
1652 if l2o.try_map(lk).is_none() {
1653 left_to_add.push(lk);
1654 }
1655 if r2o.try_map(rk).is_none() {
1656 right_to_add.push(rk + left_len)
1657 }
1658 }
1659 };
1660 }
1661 let left_to_add = left_to_add.into_iter().unique();
1662 let right_to_add = right_to_add.into_iter().unique();
1663 let mut new_output_indices = join.output_indices().clone();
1666 if !join.is_right_join() {
1667 new_output_indices.extend(left_to_add);
1668 }
1669 if !join.is_left_join() {
1670 new_output_indices.extend(right_to_add);
1671 }
1672
1673 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1674
1675 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1676 let l2o = join_with_pk
1679 .core
1680 .l2i_col_mapping()
1681 .composite(&join_with_pk.core.i2o_col_mapping());
1682 let r2o = join_with_pk
1683 .core
1684 .r2i_col_mapping()
1685 .composite(&join_with_pk.core.i2o_col_mapping());
1686 let left_right_stream_keys = join_with_pk
1687 .left()
1688 .expect_stream_key()
1689 .iter()
1690 .map(|i| l2o.map(*i))
1691 .chain(
1692 join_with_pk
1693 .right()
1694 .expect_stream_key()
1695 .iter()
1696 .map(|i| r2o.map(*i)),
1697 )
1698 .collect_vec();
1699 let plan: PlanRef = join_with_pk.into();
1700 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1701 } else {
1702 join_with_pk.into()
1703 };
1704
1705 Ok((plan, out_col_change))
1707 }
1708}
1709
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Builds the shared fixture: six `Int32` fields `v1..v6`, an empty left
    /// `LogicalValues` over `v1..v3` and an empty right one over `v4..v6`.
    async fn mock_sides() -> (Vec<Field>, LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect();
        let left_schema = Schema {
            fields: fields[0..3].to_vec(),
        };
        let right_schema = Schema {
            fields: fields[3..6].to_vec(),
        };
        let left = LogicalValues::new(vec![], left_schema, ctx.clone());
        let right = LogicalValues::new(vec![], right_schema, ctx);
        (fields, left, right)
    }

    /// An `Int32` input reference to column `index`.
    fn int32_ref(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// A function call expression; panics if the call cannot be constructed.
    fn call(func: Type, inputs: Vec<ExprImpl>) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(FunctionCall::new(func, inputs).unwrap()))
    }

    /// An inner join of `left` and `right` on `on`, as a plan reference.
    fn inner_join(left: LogicalValues, right: LogicalValues, on: ExprImpl) -> PlanRef {
        LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on),
        )
        .into()
    }

    /// Asserts that the join condition is a call whose two arguments are input refs
    /// with the given indices.
    fn assert_on_refs(join: &LogicalJoin, first: usize, second: usize) {
        let expr: ExprImpl = join.on().clone().into();
        let on_call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&on_call.inputs()[0], first);
        assert_eq_input_ref!(&on_call.inputs()[1], second);
    }

    /// Asserts that both join inputs are `LogicalValues` with the expected fields.
    fn assert_input_fields(join: &LogicalJoin, left: &[Field], right: &[Field]) {
        let left_input = join.left();
        let left_values = left_input.as_logical_values().unwrap();
        assert_eq!(left_values.schema().fields(), left);
        let right_input = join.right();
        let right_values = right_input.as_logical_values().unwrap();
        assert_eq!(right_values.schema().fields(), right);
    }

    /// Pruning an inner join on `$1 = $3` down to columns `[2, 3]` keeps the join key
    /// in the left input and outputs exactly the two required columns.
    #[tokio::test]
    async fn test_prune_join() {
        let (fields, left, right) = mock_sides().await;
        let join = inner_join(left, right, call(Type::Equal, vec![int32_ref(1), int32_ref(3)]));

        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let pruned = plan.as_logical_join().unwrap();
        assert_eq!(pruned.schema().fields().len(), 2);
        assert_eq!(pruned.schema().fields()[0], fields[2]);
        assert_eq!(pruned.schema().fields()[1], fields[3]);

        assert_on_refs(pruned, 0, 2);
        assert_input_fields(pruned, &fields[1..3], &fields[3..4]);
    }

    /// Semi and anti joins only expose one side; pruning must keep the requested
    /// columns of that side.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let (fields, left, right) = mock_sides().await;
        let on = call(Type::Equal, vec![int32_ref(1), int32_ref(4)]);

        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Right-side joins output the right input, whose fields start at `fields[3]`.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();

            for required_cols in [vec![0], vec![0, 1, 2]] {
                let plan =
                    join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
                let pruned = plan.as_logical_join().unwrap();
                assert_eq!(pruned.schema().fields().len(), required_cols.len());
                for (pos, field) in pruned.schema().fields().iter().enumerate() {
                    assert_eq!(*field, fields[offset + pos]);
                }
            }
        }
    }

    /// Pruning to exactly the join-key columns `[1, 3]` leaves one column per input.
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let (fields, left, right) = mock_sides().await;
        let join = inner_join(left, right, call(Type::Equal, vec![int32_ref(1), int32_ref(3)]));

        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let pruned = plan.as_logical_join().unwrap();
        assert_eq!(pruned.schema().fields().len(), 2);
        assert_eq!(pruned.schema().fields()[0], fields[1]);
        assert_eq!(pruned.schema().fields()[1], fields[3]);

        assert_on_refs(pruned, 0, 1);
        assert_input_fields(pruned, &fields[1..2], &fields[3..4]);
    }

    /// An inner join on `$1 = $3 AND $2 = 42` converts to a batch hash join whose
    /// predicate separates the equi-condition from the remaining condition.
    #[tokio::test]
    async fn test_join_to_batch() {
        let (_fields, left, right) = mock_sides().await;

        let eq_cond = call(Type::Equal, vec![int32_ref(1), int32_ref(3)]);
        let forty_two = ExprImpl::Literal(Box::new(Literal::new(
            Datum::Some(42_i32.into()),
            DataType::Int32,
        )));
        let non_eq_cond = call(Type::Equal, vec![int32_ref(2), forty_two]);
        let on_cond = call(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]);

        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on_cond),
        );

        let batch_plan = logical_join.to_batch().unwrap();

        let hash_join = batch_plan.as_batch_hash_join().unwrap();
        let predicate = hash_join.eq_join_predicate();
        assert_eq!(ExprImpl::from(predicate.eq_cond()), eq_cond);
        assert_eq!(
            *predicate.non_eq_cond().conjunctions.first().unwrap(),
            non_eq_cond
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_join_to_stream() {}

    /// Pruning with the required columns in non-ascending order `[3, 2]` keeps that
    /// order in the join output.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let (fields, left, right) = mock_sides().await;
        let join = inner_join(left, right, call(Type::Equal, vec![int32_ref(1), int32_ref(3)]));

        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let pruned = plan.as_logical_join().unwrap();
        assert_eq!(pruned.schema().fields().len(), 2);
        assert_eq!(pruned.schema().fields()[0], fields[3]);
        assert_eq!(pruned.schema().fields()[1], fields[2]);

        assert_on_refs(pruned, 0, 2);
        assert_input_fields(pruned, &fields[1..3], &fields[3..4]);
    }

    /// Checks the functional dependencies derived for every join type over
    /// `l0 = 0 AND l1 = r1`, with `l0 -> l1` on the left and `r0 -> r1, r2` on the right.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock().await;

        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };

        let on: ExprImpl = call(
            Type::And,
            vec![
                call(Type::Equal, vec![int32_ref(0), ExprImpl::literal_int(0)]),
                call(Type::Equal, vec![int32_ref(1), int32_ref(3)]),
            ],
        );

        let expected_fd_set = [
            (
                JoinType::Inner,
                HashSet::from([
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                HashSet::from([FunctionalDependency::with_indices(5, &[2], &[3, 4])]),
            ),
            (
                JoinType::LeftOuter,
                HashSet::from([FunctionalDependency::with_indices(5, &[0], &[1])]),
            ),
            (
                JoinType::LeftSemi,
                HashSet::from([FunctionalDependency::with_indices(2, &[0], &[1])]),
            ),
            (
                JoinType::LeftAnti,
                HashSet::from([FunctionalDependency::with_indices(2, &[0], &[1])]),
            ),
            (
                JoinType::RightSemi,
                HashSet::from([FunctionalDependency::with_indices(3, &[0], &[1, 2])]),
            ),
            (
                JoinType::RightAnti,
                HashSet::from([FunctionalDependency::with_indices(3, &[0], &[1, 2])]),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set: HashSet<_> = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect();
            assert_eq!(fd_set, expected_res);
        }
    }
}