1use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32 BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33 PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34 ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
58pub struct LogicalJoin {
59 pub base: PlanBase<Logical>,
60 core: generic::Join<PlanRef>,
61}
62
63impl Distill for LogicalJoin {
64 fn distill<'a>(&self) -> XmlNode<'a> {
65 let verbose = self.base.ctx().is_explain_verbose();
66 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67 vec.push(("type", Pretty::debug(&self.join_type())));
68
69 let concat_schema = self.core.concat_schema();
70 let cond = Pretty::debug(&ConditionDisplay {
71 condition: self.on(),
72 input_schema: &concat_schema,
73 });
74 vec.push(("on", cond));
75
76 if verbose {
77 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78 vec.push(("output", data));
79 }
80
81 childless_record("LogicalJoin", vec)
82 }
83}
84
85impl LogicalJoin {
86 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87 let core = generic::Join::with_full_output(left, right, join_type, on);
88 Self::with_core(core)
89 }
90
91 pub(crate) fn with_output_indices(
92 left: PlanRef,
93 right: PlanRef,
94 join_type: JoinType,
95 on: Condition,
96 output_indices: Vec<usize>,
97 ) -> Self {
98 let core = generic::Join::new(left, right, on, join_type, output_indices);
99 Self::with_core(core)
100 }
101
102 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103 let base = PlanBase::new_logical_with_core(&core);
104 LogicalJoin { base, core }
105 }
106
107 pub fn create(
108 left: PlanRef,
109 right: PlanRef,
110 join_type: JoinType,
111 on_clause: ExprImpl,
112 ) -> PlanRef {
113 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114 }
115
116 pub fn internal_column_num(&self) -> usize {
117 self.core.internal_column_num()
118 }
119
120 pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121 self.core.i2l_col_mapping_ignore_join_type()
122 }
123
124 pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125 self.core.i2r_col_mapping_ignore_join_type()
126 }
127
128 pub fn on(&self) -> &Condition {
130 &self.core.on
131 }
132
133 pub fn core(&self) -> &generic::Join<PlanRef> {
134 &self.core
135 }
136
137 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
139 let input_refs = self
140 .core
141 .on
142 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
143 let index_group = input_refs
144 .ones()
145 .chunk_by(|i| *i < self.core.left.schema().len());
146 let left_index = index_group
147 .into_iter()
148 .next()
149 .map_or(vec![], |group| group.1.collect_vec());
150 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
151 group
152 .1
153 .map(|i| i - self.core.left.schema().len())
154 .collect_vec()
155 });
156 (left_index, right_index)
157 }
158
159 pub fn join_type(&self) -> JoinType {
161 self.core.join_type
162 }
163
164 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
166 self.core.eq_indexes()
167 }
168
169 pub fn output_indices(&self) -> &Vec<usize> {
171 &self.core.output_indices
172 }
173
174 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
176 Self::with_core(generic::Join {
177 output_indices,
178 ..self.core.clone()
179 })
180 }
181
182 pub fn clone_with_cond(&self, on: Condition) -> Self {
184 Self::with_core(generic::Join {
185 on,
186 ..self.core.clone()
187 })
188 }
189
190 pub fn is_left_join(&self) -> bool {
191 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
192 }
193
194 pub fn is_right_join(&self) -> bool {
195 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
196 }
197
198 pub fn is_full_out(&self) -> bool {
199 self.core.is_full_out()
200 }
201
202 pub fn is_asof_join(&self) -> bool {
203 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
204 }
205
206 pub fn output_indices_are_trivial(&self) -> bool {
207 itertools::equal(
208 self.output_indices().iter().cloned(),
209 0..self.internal_column_num(),
210 )
211 }
212
213 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
218 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
219 JoinType::LeftOuter => (false, true),
220 JoinType::RightOuter => (true, false),
221 JoinType::FullOuter => (true, true),
222 _ => return join_type,
223 };
224
225 for expr in &predicate.conjunctions {
226 if let ExprImpl::FunctionCall(func) = expr {
227 match func.func_type() {
228 ExprType::Equal
229 | ExprType::NotEqual
230 | ExprType::LessThan
231 | ExprType::LessThanOrEqual
232 | ExprType::GreaterThan
233 | ExprType::GreaterThanOrEqual => {
234 for input in func.inputs() {
235 if let ExprImpl::InputRef(input) = input {
236 let idx = input.index;
237 if idx < left_col_num {
238 gen_null_in_left = false;
239 } else {
240 gen_null_in_right = false;
241 }
242 }
243 }
244 }
245 _ => {}
246 };
247 }
248 }
249
250 match (gen_null_in_left, gen_null_in_right) {
251 (true, true) => JoinType::FullOuter,
252 (true, false) => JoinType::RightOuter,
253 (false, true) => JoinType::LeftOuter,
254 (false, false) => JoinType::Inner,
255 }
256 }
257
258 fn to_batch_lookup_join_with_index_selection(
262 &self,
263 predicate: EqJoinPredicate,
264 batch_join: generic::Join<BatchPlanRef>,
265 ) -> Result<Option<BatchLookupJoin>> {
266 match batch_join.join_type {
267 JoinType::Inner
268 | JoinType::LeftOuter
269 | JoinType::LeftSemi
270 | JoinType::LeftAnti
271 | JoinType::AsofInner
272 | JoinType::AsofLeftOuter => {}
273 _ => return Ok(None),
274 };
275
276 let right = self.right();
278 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
280 logical_scan
281 } else {
282 return Ok(None);
283 };
284
285 let mut result_plan = None;
286 if let Some(lookup_join) =
288 self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
289 {
290 result_plan = Some(lookup_join);
291 }
292
293 if self
294 .core
295 .ctx()
296 .session_ctx()
297 .config()
298 .enable_index_selection()
299 {
300 let indexes = logical_scan.table_indexes();
301 for index in indexes {
302 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
303 let index_scan: PlanRef = index_scan.into();
304 let that = self.clone_with_left_right(self.left(), index_scan.clone());
305 let mut new_batch_join = batch_join.clone();
306 new_batch_join.right =
307 index_scan.to_batch().expect("index scan failed to batch");
308
309 if let Some(lookup_join) =
311 that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
312 {
313 match &result_plan {
314 None => result_plan = Some(lookup_join),
315 Some(prev_lookup_join) => {
316 if prev_lookup_join.lookup_prefix_len()
318 < lookup_join.lookup_prefix_len()
319 {
320 result_plan = Some(lookup_join)
321 }
322 }
323 }
324 }
325 }
326 }
327 }
328
329 Ok(result_plan)
330 }
331
332 fn to_batch_lookup_join(
334 &self,
335 predicate: EqJoinPredicate,
336 logical_join: generic::Join<BatchPlanRef>,
337 ) -> Result<Option<BatchLookupJoin>> {
338 let logical_scan: &LogicalScan =
339 if let Some(logical_scan) = self.core.right.as_logical_scan() {
340 logical_scan
341 } else {
342 return Ok(None);
343 };
344 Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
345 }
346
347 pub fn gen_batch_lookup_join(
348 logical_scan: &LogicalScan,
349 predicate: EqJoinPredicate,
350 logical_join: generic::Join<BatchPlanRef>,
351 is_as_of: bool,
352 ) -> Result<Option<BatchLookupJoin>> {
353 match logical_join.join_type {
354 JoinType::Inner
355 | JoinType::LeftOuter
356 | JoinType::LeftSemi
357 | JoinType::LeftAnti
358 | JoinType::AsofInner
359 | JoinType::AsofLeftOuter => {}
360 _ => return Ok(None),
361 };
362
363 let table = logical_scan.table();
364 let output_column_ids = logical_scan.output_column_ids();
365
366 let order_col_ids = table.order_column_ids();
369 let dist_key = table.distribution_key.clone();
370 let mut dist_key_in_order_key_pos = vec![];
372 for d in dist_key {
373 let pos = table
374 .order_column_indices()
375 .position(|x| x == d)
376 .expect("dist_key must in order_key");
377 dist_key_in_order_key_pos.push(pos);
378 }
379 let shortest_prefix_len = dist_key_in_order_key_pos
381 .iter()
382 .max()
383 .map_or(0, |pos| pos + 1);
384
385 if shortest_prefix_len == 0 {
387 return Ok(None);
388 }
389
390 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
392 for order_col_id in order_col_ids {
393 let mut found = false;
394 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
395 if order_col_id == output_column_ids[eq_idx] {
396 reorder_idx.push(i);
397 found = true;
398 break;
399 }
400 }
401 if !found {
402 break;
403 }
404 }
405 if reorder_idx.len() < shortest_prefix_len {
406 return Ok(None);
407 }
408 let lookup_prefix_len = reorder_idx.len();
409 let predicate = predicate.reorder(&reorder_idx);
410
411 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
413 let o2r = if let Some(project_expr) = project_expr {
415 project_expr
416 .into_iter()
417 .map(|x| x.as_input_ref().unwrap().index)
418 .collect_vec()
419 } else {
420 (0..logical_scan.output_col_idx().len()).collect_vec()
421 };
422 let left_schema_len = logical_join.left.schema().len();
423
424 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
425 offset: left_schema_len,
426 mapping: o2r.clone(),
427 };
428
429 let new_eq_cond = predicate
430 .eq_cond()
431 .rewrite_expr(&mut join_predicate_rewriter);
432
433 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
434 offset: left_schema_len,
435 };
436
437 let new_other_cond = predicate
438 .other_cond()
439 .clone()
440 .rewrite_expr(&mut join_predicate_rewriter)
441 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
442
443 let new_join_on = new_eq_cond.and(new_other_cond);
444 let new_predicate = EqJoinPredicate::create(
445 left_schema_len,
446 new_scan.schema().len(),
447 new_join_on.clone(),
448 );
449
450 if !new_predicate.has_eq() {
453 return Ok(None);
454 }
455
456 let new_join_output_indices = logical_join
459 .output_indices
460 .iter()
461 .map(|&x| {
462 if x < left_schema_len {
463 x
464 } else {
465 o2r[x - left_schema_len] + left_schema_len
466 }
467 })
468 .collect_vec();
469
470 let new_scan_output_column_ids = new_scan.output_column_ids();
471 let as_of = new_scan.as_of.clone();
472 let new_logical_scan: LogicalScan = new_scan.into();
473
474 let new_logical_join = generic::Join::new(
476 logical_join.left,
477 new_logical_scan.to_batch()?,
478 new_join_on,
479 logical_join.join_type,
480 new_join_output_indices,
481 );
482
483 let asof_desc = is_as_of
484 .then(|| {
485 Self::get_inequality_desc_from_predicate(
486 predicate.other_cond().clone(),
487 left_schema_len,
488 )
489 })
490 .transpose()?;
491
492 Ok(Some(BatchLookupJoin::new(
493 new_logical_join,
494 new_predicate,
495 table.clone(),
496 new_scan_output_column_ids,
497 lookup_prefix_len,
498 false,
499 as_of,
500 asof_desc,
501 )))
502 }
503
504 pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
505 self.core.decompose()
506 }
507}
508
509impl PlanTreeNodeBinary<Logical> for LogicalJoin {
510 fn left(&self) -> PlanRef {
511 self.core.left.clone()
512 }
513
514 fn right(&self) -> PlanRef {
515 self.core.right.clone()
516 }
517
518 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
519 Self::with_core(generic::Join {
520 left,
521 right,
522 ..self.core.clone()
523 })
524 }
525
526 fn rewrite_with_left_right(
527 &self,
528 left: PlanRef,
529 left_col_change: ColIndexMapping,
530 right: PlanRef,
531 right_col_change: ColIndexMapping,
532 ) -> (Self, ColIndexMapping) {
533 let (new_on, new_output_indices) = {
534 let (mut map, _) = left_col_change.clone().into_parts();
535 let (mut right_map, _) = right_col_change.clone().into_parts();
536 for i in right_map.iter_mut().flatten() {
537 *i += left.schema().len();
538 }
539 map.append(&mut right_map);
540 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
541
542 let new_output_indices = self
543 .output_indices()
544 .iter()
545 .map(|&i| mapping.map(i))
546 .collect::<Vec<_>>();
547 let new_on = self.on().clone().rewrite_expr(&mut mapping);
548 (new_on, new_output_indices)
549 };
550
551 let join = Self::with_output_indices(
552 left,
553 right,
554 self.join_type(),
555 new_on,
556 new_output_indices.clone(),
557 );
558
559 let new_i2o = ColIndexMapping::with_remaining_columns(
560 &new_output_indices,
561 join.internal_column_num(),
562 );
563
564 let old_o2i = self.core.o2i_col_mapping();
565
566 let old_o2l = old_o2i
567 .composite(&self.core.i2l_col_mapping())
568 .composite(&left_col_change);
569 let old_o2r = old_o2i
570 .composite(&self.core.i2r_col_mapping())
571 .composite(&right_col_change);
572 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
573 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
574
575 let out_col_change = old_o2l
576 .composite(&new_l2o)
577 .union(&old_o2r.composite(&new_r2o));
578 (join, out_col_change)
579 }
580}
581
// Macro-generated boilerplate for `LogicalJoin` as a binary `Logical` plan node.
// NOTE(review): presumably derives the general plan-tree-node impl from the
// `PlanTreeNodeBinary` impl — confirm against the macro definition.
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
583
584impl ColPrunable for LogicalJoin {
585 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
586 let required_cols = required_cols
588 .iter()
589 .map(|i| self.output_indices()[*i])
590 .collect_vec();
591 let left_len = self.left().schema().fields.len();
592
593 let total_len = self.left().schema().len() + self.right().schema().len();
594 let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
595
596 required_cols.iter().for_each(|&i| {
597 if self.is_right_join() {
598 resized_required_cols.insert(left_len + i);
599 } else {
600 resized_required_cols.insert(i);
601 }
602 });
603
604 let mut visitor = CollectInputRef::new(resized_required_cols);
607 self.on().visit_expr(&mut visitor);
608 let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
609
610 let mut left_required_cols = Vec::new();
611 let mut right_required_cols = Vec::new();
612 left_right_required_cols.iter().for_each(|&i| {
613 if i < left_len {
614 left_required_cols.push(i);
615 } else {
616 right_required_cols.push(i - left_len);
617 }
618 });
619
620 let mut on = self.on().clone();
621 let mut mapping =
622 ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
623 on = on.rewrite_expr(&mut mapping);
624
625 let new_output_indices = {
626 let required_inputs_in_output = if self.is_left_join() {
627 &left_required_cols
628 } else if self.is_right_join() {
629 &right_required_cols
630 } else {
631 &left_right_required_cols
632 };
633
634 let mapping =
635 ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
636 required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
637 };
638
639 LogicalJoin::with_output_indices(
640 self.left().prune_col(&left_required_cols, ctx),
641 self.right().prune_col(&right_required_cols, ctx),
642 self.join_type(),
643 on,
644 new_output_indices,
645 )
646 .into()
647 }
648}
649
650impl ExprRewritable<Logical> for LogicalJoin {
651 fn has_rewritable_expr(&self) -> bool {
652 true
653 }
654
655 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
656 let mut core = self.core.clone();
657 core.rewrite_exprs(r);
658 Self {
659 base: self.base.clone_with_new_plan_id(),
660 core,
661 }
662 .into()
663 }
664}
665
666impl ExprVisitable for LogicalJoin {
667 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
668 self.core.visit_exprs(v);
669 }
670}
671
672fn derive_predicate_from_eq_condition(
690 expr: &ExprImpl,
691 eq_condition: &EqJoinPredicate,
692 col_num: usize,
693 expr_is_left: bool,
694) -> Option<ExprImpl> {
695 if expr.is_impure() {
696 return None;
697 }
698 let eq_indices = eq_condition
699 .eq_indexes_typed()
700 .iter()
701 .filter_map(|(l, r)| {
702 if l.return_type() != r.return_type() {
703 None
704 } else if expr_is_left {
705 Some(l.index())
706 } else {
707 Some(r.index())
708 }
709 })
710 .collect_vec();
711 if expr
712 .collect_input_refs(col_num)
713 .ones()
714 .any(|index| !eq_indices.contains(&index))
715 {
716 return None;
718 }
719 let other_side_mapping = if expr_is_left {
722 eq_condition.eq_indexes_typed().into_iter().collect()
723 } else {
724 eq_condition
725 .eq_indexes_typed()
726 .into_iter()
727 .map(|(x, y)| (y, x))
728 .collect()
729 };
730 struct InputRefsRewriter {
731 mapping: HashMap<InputRef, InputRef>,
732 }
733 impl ExprRewriter for InputRefsRewriter {
734 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
735 self.mapping[&input_ref].clone().into()
736 }
737 }
738 Some(
739 InputRefsRewriter {
740 mapping: other_side_mapping,
741 }
742 .rewrite_expr(expr.clone()),
743 )
744}
745
746struct LookupJoinPredicateRewriter {
748 offset: usize,
749 mapping: Vec<usize>,
750}
751impl ExprRewriter for LookupJoinPredicateRewriter {
752 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
753 if input_ref.index() < self.offset {
754 input_ref.into()
755 } else {
756 InputRef::new(
757 self.mapping[input_ref.index() - self.offset] + self.offset,
758 input_ref.return_type(),
759 )
760 .into()
761 }
762 }
763}
764
765struct LookupJoinScanPredicateRewriter {
767 offset: usize,
768}
769impl ExprRewriter for LookupJoinScanPredicateRewriter {
770 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
771 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
772 }
773}
774
775impl PredicatePushdown for LogicalJoin {
776 fn predicate_pushdown(
800 &self,
801 predicate: Condition,
802 ctx: &mut PredicatePushdownContext,
803 ) -> PlanRef {
804 let mut predicate = {
806 let mut mapping = self.core.o2i_col_mapping();
807 predicate.rewrite_expr(&mut mapping)
808 };
809
810 let left_col_num = self.left().schema().len();
811 let right_col_num = self.right().schema().len();
812 let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
813
814 let push_down_temporal_predicate = self.temporal_join_on().is_none();
815
816 let (left_from_filter, right_from_filter, on) = push_down_into_join(
817 &mut predicate,
818 left_col_num,
819 right_col_num,
820 join_type,
821 push_down_temporal_predicate,
822 );
823
824 let mut new_on = self.on().clone().and(on);
825 let (left_from_on, right_from_on) = push_down_join_condition(
826 &mut new_on,
827 left_col_num,
828 right_col_num,
829 join_type,
830 push_down_temporal_predicate,
831 );
832
833 let left_predicate = left_from_filter.and(left_from_on);
834 let right_predicate = right_from_filter.and(right_from_on);
835
836 let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
838
839 let right_from_left = if matches!(
841 join_type,
842 JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
843 ) {
844 Condition {
845 conjunctions: left_predicate
846 .conjunctions
847 .iter()
848 .filter_map(|expr| {
849 derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
850 })
851 .collect(),
852 }
853 } else {
854 Condition::true_cond()
855 };
856
857 let left_from_right = if matches!(
859 join_type,
860 JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
861 ) {
862 Condition {
863 conjunctions: right_predicate
864 .conjunctions
865 .iter()
866 .filter_map(|expr| {
867 derive_predicate_from_eq_condition(
868 expr,
869 &eq_condition,
870 right_col_num,
871 false,
872 )
873 })
874 .collect(),
875 }
876 } else {
877 Condition::true_cond()
878 };
879
880 let left_predicate = left_predicate.and(left_from_right);
881 let right_predicate = right_predicate.and(right_from_left);
882
883 let new_left = self.left().predicate_pushdown(left_predicate, ctx);
884 let new_right = self.right().predicate_pushdown(right_predicate, ctx);
885 let new_join = LogicalJoin::with_output_indices(
886 new_left,
887 new_right,
888 join_type,
889 new_on,
890 self.output_indices().clone(),
891 );
892
893 let mut mapping = self.core.i2o_col_mapping();
894 predicate = predicate.rewrite_expr(&mut mapping);
895 LogicalFilter::create(new_join.into(), predicate)
896 }
897}
898
899#[derive(Clone, Copy)]
900struct TemporalJoinScan<'a>(&'a LogicalScan);
901
902impl<'a> Deref for TemporalJoinScan<'a> {
903 type Target = LogicalScan;
904
905 fn deref(&self) -> &Self::Target {
906 self.0
907 }
908}
909
910impl LogicalJoin {
911 fn get_stream_input_for_hash_join(
912 &self,
913 predicate: &EqJoinPredicate,
914 ctx: &mut ToStreamContext,
915 ) -> Result<(StreamPlanRef, StreamPlanRef)> {
916 use super::stream::prelude::*;
917
918 let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
919 let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();
920
921 let logical_right = try_enforce_locality_requirement(self.right(), &rhs_join_key_idx);
922 let mut right = logical_right.to_stream_with_dist_required(
923 &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
924 ctx,
925 )?;
926 let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
927 let r2l =
928 predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
929 let l2r =
930 predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
931 let mut left;
932 let right_dist = right.distribution();
933 match right_dist {
934 Distribution::HashShard(_) => {
935 let left_dist = r2l
936 .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
937 left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
938 }
939 Distribution::UpstreamHashShard(_, _) => {
940 left = logical_left.to_stream_with_dist_required(
941 &RequiredDist::shard_by_key(
942 self.left().schema().len(),
943 &predicate.left_eq_indexes(),
944 ),
945 ctx,
946 )?;
947 let left_dist = left.distribution();
948 match left_dist {
949 Distribution::HashShard(_) => {
950 let right_dist = l2r.rewrite_required_distribution(
951 &RequiredDist::PhysicalDist(left_dist.clone()),
952 );
953 right = right_dist.streaming_enforce_if_not_satisfies(right)?
954 }
955 Distribution::UpstreamHashShard(_, _) => {
956 left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
957 .streaming_enforce_if_not_satisfies(left)?;
958 right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
959 .streaming_enforce_if_not_satisfies(right)?;
960 }
961 _ => unreachable!(),
962 }
963 }
964 _ => unreachable!(),
965 }
966 Ok((left, right))
967 }
968
969 fn to_stream_hash_join(
970 &self,
971 predicate: EqJoinPredicate,
972 ctx: &mut ToStreamContext,
973 ) -> Result<StreamPlanRef> {
974 use super::stream::prelude::*;
975
976 assert!(predicate.has_eq());
977 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
978
979 let core = self.core.clone_with_inputs(left, right);
980
981 let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
990
991 let force_filter_inside_join = self
992 .base
993 .ctx()
994 .session_ctx()
995 .config()
996 .streaming_force_filter_inside_join();
997
998 let pull_filter = self.join_type() == JoinType::Inner
999 && stream_hash_join.eq_join_predicate().has_non_eq()
1000 && stream_hash_join.inequality_pairs().is_empty()
1001 && (!force_filter_inside_join);
1002 if pull_filter {
1003 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1004
1005 let mut core = core;
1006 core.output_indices = default_indices.clone();
1007 let eq_cond = EqJoinPredicate::new(
1009 Condition::true_cond(),
1010 predicate.eq_keys().to_vec(),
1011 self.left().schema().len(),
1012 self.right().schema().len(),
1013 );
1014 core.on = eq_cond.eq_cond();
1015 let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
1016 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1017 let plan = StreamFilter::new(logical_filter).into();
1018 if self.output_indices() != &default_indices {
1019 let logical_project = generic::Project::with_mapping(
1020 plan,
1021 ColIndexMapping::with_remaining_columns(
1022 self.output_indices(),
1023 self.internal_column_num(),
1024 ),
1025 );
1026 Ok(StreamProject::new(logical_project).into())
1027 } else {
1028 Ok(plan)
1029 }
1030 } else {
1031 Ok(stream_hash_join.into())
1032 }
1033 }
1034
1035 fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1036 if let Some(logical_scan) = self.core.right.as_logical_scan() {
1037 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1038 .then_some(TemporalJoinScan(logical_scan))
1039 } else {
1040 None
1041 }
1042 }
1043
1044 fn should_be_stream_temporal_join<'a>(
1045 &'a self,
1046 ctx: &ToStreamContext,
1047 ) -> Result<Option<TemporalJoinScan<'a>>> {
1048 Ok(if let Some(scan) = self.temporal_join_on() {
1049 if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1050 return Err(RwError::from(ErrorCode::NotSupported(
1051 "Temporal join with snapshot backfill not supported".into(),
1052 "Please use arrangement backfill".into(),
1053 )));
1054 }
1055 if scan.cross_database() {
1056 return Err(RwError::from(ErrorCode::NotSupported(
1057 "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1058 "Please ensure both tables are in the same database".into(),
1059 )));
1060 }
1061 Some(scan)
1062 } else {
1063 None
1064 })
1065 }
1066
1067 fn to_stream_temporal_join_with_index_selection(
1068 &self,
1069 logical_scan: TemporalJoinScan<'_>,
1070 predicate: EqJoinPredicate,
1071 ctx: &mut ToStreamContext,
1072 ) -> Result<StreamPlanRef> {
1073 let mut result_plan: Result<StreamTemporalJoin> =
1075 self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1076 if let Ok(temporal_join) = &result_plan
1078 && temporal_join.eq_join_predicate().eq_indexes().len()
1079 == logical_scan.primary_key().len()
1080 {
1081 return result_plan.map(|x| x.into());
1082 }
1083 if self
1084 .core
1085 .ctx()
1086 .session_ctx()
1087 .config()
1088 .enable_index_selection()
1089 {
1090 let indexes = logical_scan.table_indexes();
1091 for index in indexes {
1092 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1094 let index_scan: PlanRef = index_scan.into();
1095 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1096 if let Ok(temporal_join) = that.to_stream_temporal_join(
1097 that.temporal_join_on().expect(
1098 "index scan created from temporal join scan must also be temporal join",
1099 ),
1100 predicate.clone(),
1101 ctx,
1102 ) {
1103 match &result_plan {
1104 Err(_) => result_plan = Ok(temporal_join),
1105 Ok(prev_temporal_join) => {
1106 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1108 < temporal_join.eq_join_predicate().eq_indexes().len()
1109 {
1110 result_plan = Ok(temporal_join)
1111 }
1112 }
1113 }
1114 }
1115 }
1116 }
1117 }
1118
1119 result_plan.map(|x| x.into())
1120 }
1121
1122 fn temporal_join_scan_predicate_pull_up(
1123 logical_scan: TemporalJoinScan<'_>,
1124 predicate: EqJoinPredicate,
1125 output_indices: &[usize],
1126 left_schema_len: usize,
1127 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1128 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1130 let o2r = if let Some(project_expr) = project_expr {
1132 project_expr
1133 .into_iter()
1134 .map(|x| x.as_input_ref().unwrap().index)
1135 .collect_vec()
1136 } else {
1137 (0..logical_scan.output_col_idx().len()).collect_vec()
1138 };
1139 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1140 offset: left_schema_len,
1141 mapping: o2r.clone(),
1142 };
1143
1144 let new_eq_cond = predicate
1145 .eq_cond()
1146 .rewrite_expr(&mut join_predicate_rewriter);
1147
1148 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1149 offset: left_schema_len,
1150 };
1151
1152 let new_other_cond = predicate
1153 .other_cond()
1154 .clone()
1155 .rewrite_expr(&mut join_predicate_rewriter)
1156 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1157
1158 let new_join_on = new_eq_cond.and(new_other_cond);
1159
1160 let new_predicate = EqJoinPredicate::create(
1161 left_schema_len,
1162 new_scan.schema().len(),
1163 new_join_on.clone(),
1164 );
1165
1166 let new_join_output_indices = output_indices
1169 .iter()
1170 .map(|&x| {
1171 if x < left_schema_len {
1172 x
1173 } else {
1174 o2r[x - left_schema_len] + left_schema_len
1175 }
1176 })
1177 .collect_vec();
1178
1179 let new_stream_table_scan =
1180 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1181 Ok((
1182 new_stream_table_scan,
1183 new_predicate,
1184 new_join_on,
1185 new_join_output_indices,
1186 ))
1187 }
1188
1189 fn to_stream_temporal_join(
1190 &self,
1191 logical_scan: TemporalJoinScan<'_>,
1192 predicate: EqJoinPredicate,
1193 ctx: &mut ToStreamContext,
1194 ) -> Result<StreamTemporalJoin> {
1195 use super::stream::prelude::*;
1196
1197 assert!(predicate.has_eq());
1198
1199 let table = logical_scan.table();
1200 let output_column_ids = logical_scan.output_column_ids();
1201
1202 let order_col_ids = table.order_column_ids();
1205 let dist_key = table.distribution_key.clone();
1206
1207 let mut dist_key_in_order_key_pos = vec![];
1208 for d in dist_key {
1209 let pos = table
1210 .order_column_indices()
1211 .position(|x| x == d)
1212 .expect("dist_key must in order_key");
1213 dist_key_in_order_key_pos.push(pos);
1214 }
1215 let shortest_prefix_len = dist_key_in_order_key_pos
1217 .iter()
1218 .max()
1219 .map_or(0, |pos| pos + 1);
1220
1221 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1223 for order_col_id in order_col_ids {
1224 let mut found = false;
1225 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1226 if order_col_id == output_column_ids[eq_idx] {
1227 reorder_idx.push(i);
1228 found = true;
1229 break;
1230 }
1231 }
1232 if !found {
1233 break;
1234 }
1235 }
1236 if reorder_idx.len() < shortest_prefix_len {
1237 return Err(RwError::from(ErrorCode::NotSupported(
1239 "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1240 "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1241 )));
1242 }
1243 let lookup_prefix_len = reorder_idx.len();
1244 let predicate = predicate.reorder(&reorder_idx);
1245
1246 let required_dist = if dist_key_in_order_key_pos.is_empty() {
1247 RequiredDist::single()
1248 } else {
1249 let left_eq_indexes = predicate.left_eq_indexes();
1250 let left_dist_key = dist_key_in_order_key_pos
1251 .iter()
1252 .map(|pos| left_eq_indexes[*pos])
1253 .collect_vec();
1254
1255 RequiredDist::hash_shard(&left_dist_key)
1256 };
1257
1258 let lhs_join_key_idx = predicate
1259 .eq_indexes()
1260 .into_iter()
1261 .map(|(l, _)| l)
1262 .collect_vec();
1263 let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
1264 let left = logical_left.to_stream(ctx)?;
1265 let left = required_dist.stream_enforce(left);
1267
1268 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1269 Self::temporal_join_scan_predicate_pull_up(
1270 logical_scan,
1271 predicate,
1272 self.output_indices(),
1273 self.left().schema().len(),
1274 )?;
1275
1276 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1277 if !new_predicate.has_eq() {
1278 return Err(RwError::from(ErrorCode::NotSupported(
1279 "Temporal join requires a non trivial join condition".into(),
1280 "Please remove the false condition of the join".into(),
1281 )));
1282 }
1283
1284 let new_logical_join = generic::Join::new(
1286 left,
1287 right,
1288 new_join_on,
1289 self.join_type(),
1290 new_join_output_indices,
1291 );
1292
1293 let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1294
1295 StreamTemporalJoin::new(new_logical_join, new_predicate, false)
1296 }
1297
1298 fn to_stream_nested_loop_temporal_join(
1299 &self,
1300 logical_scan: TemporalJoinScan<'_>,
1301 predicate: EqJoinPredicate,
1302 ctx: &mut ToStreamContext,
1303 ) -> Result<StreamPlanRef> {
1304 use super::stream::prelude::*;
1305 assert!(!predicate.has_eq());
1306
1307 let left = self.left().to_stream_with_dist_required(
1308 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1309 ctx,
1310 )?;
1311 assert!(left.as_stream_exchange().is_some());
1312
1313 if self.join_type() != JoinType::Inner {
1314 return Err(RwError::from(ErrorCode::NotSupported(
1315 "Temporal join requires an inner join".into(),
1316 "Please use an inner join".into(),
1317 )));
1318 }
1319
1320 if !left.append_only() {
1321 return Err(RwError::from(ErrorCode::NotSupported(
1322 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1323 "Please ensure the left hash side is append only".into(),
1324 )));
1325 }
1326
1327 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1328 Self::temporal_join_scan_predicate_pull_up(
1329 logical_scan,
1330 predicate,
1331 self.output_indices(),
1332 self.left().schema().len(),
1333 )?;
1334
1335 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1336
1337 let new_logical_join = generic::Join::new(
1339 left,
1340 right,
1341 new_join_on,
1342 self.join_type(),
1343 new_join_output_indices,
1344 );
1345
1346 Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1347 }
1348
1349 fn to_stream_dynamic_filter(
1350 &self,
1351 predicate: Condition,
1352 ctx: &mut ToStreamContext,
1353 ) -> Result<Option<StreamPlanRef>> {
1354 use super::stream::prelude::*;
1355
1356 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1362 return Ok(None);
1363 }
1364
1365 if !self.right().max_one_row() {
1367 return Ok(None);
1368 }
1369 if self.right().schema().len() != 1 {
1370 return Ok(None);
1371 }
1372
1373 if predicate.conjunctions.len() > 1 {
1375 return Ok(None);
1376 }
1377 let expr: ExprImpl = predicate.into();
1378 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1379 Some(v) => v,
1380 None => return Ok(None),
1381 };
1382
1383 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1384 && right_ref.index == self.left().schema().len() ;
1385 if !condition_cross_inputs {
1386 return Ok(None);
1388 }
1389
1390 if self.left().schema().fields()[left_ref.index].data_type
1392 != self.right().schema().fields()[0].data_type
1393 {
1394 return Ok(None);
1395 }
1396
1397 let all_output_from_left = self
1399 .output_indices()
1400 .iter()
1401 .all(|i| *i < self.left().schema().len());
1402 if !all_output_from_left {
1403 return Ok(None);
1404 }
1405
1406 let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1407 let right = self.right().to_stream_with_dist_required(
1408 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1409 ctx,
1410 )?;
1411
1412 assert!(right.as_stream_exchange().is_some());
1413 assert_eq!(
1414 *right.inputs().iter().exactly_one().unwrap().distribution(),
1415 Distribution::Single
1416 );
1417
1418 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1419 let plan = StreamDynamicFilter::new(core)?.into();
1420 if self
1422 .output_indices()
1423 .iter()
1424 .copied()
1425 .ne(0..self.left().schema().len())
1426 {
1427 let logical_project = generic::Project::with_mapping(
1430 plan,
1431 ColIndexMapping::with_remaining_columns(
1432 self.output_indices(),
1433 self.left().schema().len(),
1434 ),
1435 );
1436 Ok(Some(StreamProject::new(logical_project).into()))
1437 } else {
1438 Ok(Some(plan))
1439 }
1440 }
1441
1442 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1443 let predicate = EqJoinPredicate::create(
1444 self.left().schema().len(),
1445 self.right().schema().len(),
1446 self.on().clone(),
1447 );
1448 assert!(predicate.has_eq());
1449
1450 let join = self
1451 .core
1452 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1453
1454 Ok(self
1455 .to_batch_lookup_join(predicate, join)?
1456 .expect("Fail to convert to lookup join")
1457 .into())
1458 }
1459
1460 fn to_stream_asof_join(
1461 &self,
1462 predicate: EqJoinPredicate,
1463 ctx: &mut ToStreamContext,
1464 ) -> Result<StreamPlanRef> {
1465 use super::stream::prelude::*;
1466
1467 if predicate.eq_keys().is_empty() {
1468 return Err(ErrorCode::InvalidInputSyntax(
1469 "AsOf join requires at least 1 equal condition".to_owned(),
1470 )
1471 .into());
1472 }
1473
1474 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1475 let left_len = left.schema().len();
1476 let core = self.core.clone_with_inputs(left, right);
1477
1478 let inequality_desc =
1479 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1480
1481 Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1482 }
1483
1484 fn to_batch_hash_join(
1486 &self,
1487 logical_join: generic::Join<BatchPlanRef>,
1488 predicate: EqJoinPredicate,
1489 ) -> Result<BatchPlanRef> {
1490 use super::batch::prelude::*;
1491
1492 let left_schema_len = logical_join.left.schema().len();
1493 let asof_desc = self
1494 .is_asof_join()
1495 .then(|| {
1496 Self::get_inequality_desc_from_predicate(
1497 predicate.other_cond().clone(),
1498 left_schema_len,
1499 )
1500 })
1501 .transpose()?;
1502
1503 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1504 Ok(batch_join.into())
1505 }
1506
1507 pub fn get_inequality_desc_from_predicate(
1508 predicate: Condition,
1509 left_input_len: usize,
1510 ) -> Result<AsOfJoinDesc> {
1511 let expr: ExprImpl = predicate.into();
1512 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1513 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1514 {
1515 Ok(AsOfJoinDesc {
1516 left_idx: left_input_ref.index() as u32,
1517 right_idx: (right_input_ref.index() - left_input_len) as u32,
1518 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1519 })
1520 } else {
1521 bail!("inequal condition from the same side should be push down in optimizer");
1522 }
1523 } else {
1524 Err(ErrorCode::InvalidInputSyntax(
1525 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1526 )
1527 .into())
1528 }
1529 }
1530
1531 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1532 match expr_type {
1533 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1534 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1535 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1536 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1537 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1538 "Invalid comparison type: {}",
1539 expr_type.as_str_name()
1540 ))
1541 .into()),
1542 }
1543 }
1544}
1545
1546impl ToBatch for LogicalJoin {
1547 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1548 let predicate = EqJoinPredicate::create(
1549 self.left().schema().len(),
1550 self.right().schema().len(),
1551 self.on().clone(),
1552 );
1553
1554 let batch_join = self
1555 .core
1556 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1557
1558 let ctx = self.base.ctx();
1559 let config = ctx.session_ctx().config();
1560
1561 if predicate.has_eq() {
1562 if !predicate.eq_keys_are_type_aligned() {
1563 return Err(ErrorCode::InternalError(format!(
1564 "Join eq keys are not aligned for predicate: {predicate:?}"
1565 ))
1566 .into());
1567 }
1568 if config.batch_enable_lookup_join()
1569 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1570 predicate.clone(),
1571 batch_join.clone(),
1572 )?
1573 {
1574 return Ok(lookup_join.into());
1575 }
1576 self.to_batch_hash_join(batch_join, predicate)
1577 } else if self.is_asof_join() {
1578 Err(ErrorCode::InvalidInputSyntax(
1579 "AsOf join requires at least 1 equal condition".to_owned(),
1580 )
1581 .into())
1582 } else {
1583 Ok(BatchNestedLoopJoin::new(batch_join).into())
1585 }
1586 }
1587}
1588
1589impl ToStream for LogicalJoin {
1590 fn to_stream(
1591 &self,
1592 ctx: &mut ToStreamContext,
1593 ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1594 if self
1595 .on()
1596 .conjunctions
1597 .iter()
1598 .any(|cond| cond.count_nows() > 0)
1599 {
1600 return Err(ErrorCode::NotSupported(
1601 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1602 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1603 }
1604
1605 let predicate = EqJoinPredicate::create(
1606 self.left().schema().len(),
1607 self.right().schema().len(),
1608 self.on().clone(),
1609 );
1610
1611 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1612 self.to_stream_asof_join(predicate, ctx)
1613 } else if predicate.has_eq() {
1614 if !predicate.eq_keys_are_type_aligned() {
1615 return Err(ErrorCode::InternalError(format!(
1616 "Join eq keys are not aligned for predicate: {predicate:?}"
1617 ))
1618 .into());
1619 }
1620
1621 if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1622 self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1623 } else {
1624 self.to_stream_hash_join(predicate, ctx)
1625 }
1626 } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1627 self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1628 } else if let Some(dynamic_filter) =
1629 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1630 {
1631 Ok(dynamic_filter)
1632 } else {
1633 Err(RwError::from(ErrorCode::NotSupported(
1634 "streaming nested-loop join".to_owned(),
1635 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1636 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1637 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1638 )))
1639 }
1640 }
1641
1642 fn logical_rewrite_for_stream(
1643 &self,
1644 ctx: &mut RewriteStreamContext,
1645 ) -> Result<(PlanRef, ColIndexMapping)> {
1646 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1647 let left_len = left.schema().len();
1648 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1649 let (join, out_col_change) = self.rewrite_with_left_right(
1650 left.clone(),
1651 left_col_change,
1652 right.clone(),
1653 right_col_change,
1654 );
1655
1656 let mapping = ColIndexMapping::with_remaining_columns(
1657 join.output_indices(),
1658 join.internal_column_num(),
1659 );
1660
1661 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1662 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1663
1664 let mut left_to_add = left
1666 .expect_stream_key()
1667 .iter()
1668 .cloned()
1669 .filter(|i| l2o.try_map(*i).is_none())
1670 .collect_vec();
1671
1672 let mut right_to_add = right
1673 .expect_stream_key()
1674 .iter()
1675 .filter(|&&i| r2o.try_map(i).is_none())
1676 .map(|&i| i + left_len)
1677 .collect_vec();
1678
1679 let right_len = right.schema().len();
1682 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1683
1684 let either_or_both = self.core.add_which_join_key_to_pk();
1685
1686 for (lk, rk) in eq_predicate.eq_indexes() {
1687 match either_or_both {
1688 EitherOrBoth::Left(_) => {
1689 if l2o.try_map(lk).is_none() {
1690 left_to_add.push(lk);
1691 }
1692 }
1693 EitherOrBoth::Right(_) => {
1694 if r2o.try_map(rk).is_none() {
1695 right_to_add.push(rk + left_len)
1696 }
1697 }
1698 EitherOrBoth::Both(_, _) => {
1699 if l2o.try_map(lk).is_none() {
1700 left_to_add.push(lk);
1701 }
1702 if r2o.try_map(rk).is_none() {
1703 right_to_add.push(rk + left_len)
1704 }
1705 }
1706 };
1707 }
1708 let left_to_add = left_to_add.into_iter().unique();
1709 let right_to_add = right_to_add.into_iter().unique();
1710 let mut new_output_indices = join.output_indices().clone();
1713 if !join.is_right_join() {
1714 new_output_indices.extend(left_to_add);
1715 }
1716 if !join.is_left_join() {
1717 new_output_indices.extend(right_to_add);
1718 }
1719
1720 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1721
1722 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1723 let l2o = join_with_pk
1726 .core
1727 .l2i_col_mapping()
1728 .composite(&join_with_pk.core.i2o_col_mapping());
1729 let r2o = join_with_pk
1730 .core
1731 .r2i_col_mapping()
1732 .composite(&join_with_pk.core.i2o_col_mapping());
1733 let left_right_stream_keys = join_with_pk
1734 .left()
1735 .expect_stream_key()
1736 .iter()
1737 .map(|i| l2o.map(*i))
1738 .chain(
1739 join_with_pk
1740 .right()
1741 .expect_stream_key()
1742 .iter()
1743 .map(|i| r2o.map(*i)),
1744 )
1745 .collect_vec();
1746 let plan: PlanRef = join_with_pk.into();
1747 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1748 } else {
1749 join_with_pk.into()
1750 };
1751
1752 Ok((plan, out_col_change))
1754 }
1755
1756 fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1757 let mut ctx = ToStreamContext::new(false);
1758 if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1760 let o2i_mapping = self.core.o2i_col_mapping();
1762 let left_input_columns = columns
1763 .iter()
1764 .map(|&col| o2i_mapping.try_map(col))
1765 .collect::<Option<Vec<usize>>>()?;
1766 if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1767 return Some(
1768 self.clone_with_left_right(better_left_plan, self.right())
1769 .into(),
1770 );
1771 }
1772 }
1773 None
1774 }
1775}
1776
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Returns the `Int32` fields `v1`..`v6`; the first three form the left schema and the
    /// last three the right schema of the joins built in these tests.
    fn int32_fields() -> Vec<Field> {
        (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect()
    }

    /// Creates two empty `LogicalValues` on one mocked optimizer context: the left one exposes
    /// `fields[0..3]`, the right one `fields[3..6]`.
    async fn empty_join_inputs(fields: &[Field]) -> (LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock().await;
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        (left, right)
    }

    /// Builds an `Int32` input reference to column `index`.
    fn int32_input_ref(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// Builds the expression `$left = $right` over two `Int32` input references.
    fn eq_condition(left: usize, right: usize) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![int32_input_ref(left), int32_input_ref(right)],
            )
            .unwrap(),
        ))
    }

    /// Pruning an inner join on `$1 = $3` to columns `[2, 3]` keeps the join key on the left
    /// input and remaps the condition to the pruned inputs.
    #[tokio::test]
    async fn test_prune_join() {
        let fields = int32_fields();
        let (left, right) = empty_join_inputs(&fields).await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(eq_condition(1, 3)),
        )
        .into();

        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Pruning semi and anti joins only exposes columns of the side the join outputs.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let fields = int32_fields();
        let (left, right) = empty_join_inputs(&fields).await;
        let on = eq_condition(1, 4);

        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Right joins output the right input, whose fields start at index 3.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();

            let required_cols = vec![0];
            let plan =
                join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            let required_cols = vec![0, 1, 2];
            let plan =
                join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 3);
            for i in 0..3 {
                assert_eq!(as_plan.schema().fields()[i], fields[offset + i]);
            }
        }
    }

    /// Pruning an inner join on `$1 = $3` to exactly the join key columns `[1, 3]` leaves
    /// one column on each input.
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let fields = int32_fields();
        let (left, right) = empty_join_inputs(&fields).await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(eq_condition(1, 3)),
        )
        .into();

        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// An inner join on `$1 = $3 AND $2 = 42` converts to a batch hash join whose equality
    /// condition is `$1 = $3` and whose first non-equality conjunction is `$2 = 42`.
    #[tokio::test]
    async fn test_join_to_batch() {
        let fields = int32_fields();
        let (left, right) = empty_join_inputs(&fields).await;

        let eq_cond = eq_condition(1, 3);
        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    int32_input_ref(2),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Datum::Some(42_i32.into()),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on_cond),
        );

        let result = logical_join.to_batch().unwrap();

        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Currently ignored and has no body.
    #[tokio::test]
    #[ignore]
    async fn test_join_to_stream() {}

    /// Pruning to `[3, 2]` keeps the requested column order in the join output.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let fields = int32_fields();
        let (left, right) = empty_join_inputs(&fields).await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(eq_condition(1, 3)),
        )
        .into();

        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Checks the functional dependencies derived for every join type, given `l0 -> l1` on
    /// the left input, `r0 -> {r1, r2}` on the right input and the condition
    /// `$0 = 0 AND $1 = $3`.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock().await;

        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };

        let left_col_is_zero: ExprImpl = FunctionCall::new(
            Type::Equal,
            vec![
                InputRef::new(0, DataType::Int32).into(),
                ExprImpl::literal_int(0),
            ],
        )
        .unwrap()
        .into();
        let cross_equality: ExprImpl = FunctionCall::new(
            Type::Equal,
            vec![
                InputRef::new(1, DataType::Int32).into(),
                InputRef::new(3, DataType::Int32).into(),
            ],
        )
        .unwrap()
        .into();
        let on: ExprImpl = FunctionCall::new(Type::And, vec![left_col_is_zero, cross_equality])
            .unwrap()
            .into();

        let expected_fd_set = [
            (
                JoinType::Inner,
                HashSet::from([
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                HashSet::from([FunctionalDependency::with_indices(5, &[2], &[3, 4])]),
            ),
            (
                JoinType::LeftOuter,
                HashSet::from([FunctionalDependency::with_indices(5, &[0], &[1])]),
            ),
            (
                JoinType::LeftSemi,
                HashSet::from([FunctionalDependency::with_indices(2, &[0], &[1])]),
            ),
            (
                JoinType::LeftAnti,
                HashSet::from([FunctionalDependency::with_indices(2, &[0], &[1])]),
            ),
            (
                JoinType::RightSemi,
                HashSet::from([FunctionalDependency::with_indices(3, &[0], &[1, 2])]),
            ),
            (
                JoinType::RightAnti,
                HashSet::from([FunctionalDependency::with_indices(3, &[0], &[1, 2])]),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}