1use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32 BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33 PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34 ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
/// Logical plan node joining a left and a right input.
///
/// It pairs the shared `generic::Join` core with the logical `PlanBase` that is derived
/// from that core in `with_core`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan-node base, built from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The join itself: both inputs, the join type, the `on` predicate and the output indices.
    core: generic::Join<PlanRef>,
}
62
63impl Distill for LogicalJoin {
64 fn distill<'a>(&self) -> XmlNode<'a> {
65 let verbose = self.base.ctx().is_explain_verbose();
66 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67 vec.push(("type", Pretty::debug(&self.join_type())));
68
69 let concat_schema = self.core.concat_schema();
70 let cond = Pretty::debug(&ConditionDisplay {
71 condition: self.on(),
72 input_schema: &concat_schema,
73 });
74 vec.push(("on", cond));
75
76 if verbose {
77 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78 vec.push(("output", data));
79 }
80
81 childless_record("LogicalJoin", vec)
82 }
83}
84
85impl LogicalJoin {
86 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87 let core = generic::Join::with_full_output(left, right, join_type, on);
88 Self::with_core(core)
89 }
90
91 pub(crate) fn with_output_indices(
92 left: PlanRef,
93 right: PlanRef,
94 join_type: JoinType,
95 on: Condition,
96 output_indices: Vec<usize>,
97 ) -> Self {
98 let core = generic::Join::new(left, right, on, join_type, output_indices);
99 Self::with_core(core)
100 }
101
102 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103 let base = PlanBase::new_logical_with_core(&core);
104 LogicalJoin { base, core }
105 }
106
107 pub fn create(
108 left: PlanRef,
109 right: PlanRef,
110 join_type: JoinType,
111 on_clause: ExprImpl,
112 ) -> PlanRef {
113 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114 }
115
/// Number of internal columns of the join, as reported by the core.
///
/// `output_indices` index into this internal column space (see
/// `output_indices_are_trivial`).
pub fn internal_column_num(&self) -> usize {
    self.core.internal_column_num()
}
119
/// Delegates to the core's `i2l_col_mapping_ignore_join_type`.
///
/// NOTE(review): by its name this maps internal column indices to left-input indices
/// without taking the join type into account — confirm against `generic::Join`.
pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
    self.core.i2l_col_mapping_ignore_join_type()
}
123
/// Delegates to the core's `i2r_col_mapping_ignore_join_type`.
///
/// NOTE(review): by its name this maps internal column indices to right-input indices
/// without taking the join type into account — confirm against `generic::Join`.
pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
    self.core.i2r_col_mapping_ignore_join_type()
}
127
128 pub fn on(&self) -> &Condition {
130 self.core
131 .on
132 .as_condition_ref()
133 .expect("logical join should store predicate as Condition")
134 }
135
/// Borrows the underlying generic join core.
pub fn core(&self) -> &generic::Join<PlanRef> {
    &self.core
}
139
140 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
142 let input_refs = self
143 .core
144 .on
145 .as_condition_ref()
146 .expect("logical join should store predicate as Condition")
147 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
148 let index_group = input_refs
149 .ones()
150 .chunk_by(|i| *i < self.core.left.schema().len());
151 let left_index = index_group
152 .into_iter()
153 .next()
154 .map_or(vec![], |group| group.1.collect_vec());
155 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
156 group
157 .1
158 .map(|i| i - self.core.left.schema().len())
159 .collect_vec()
160 });
161 (left_index, right_index)
162 }
163
/// Returns the join type stored in the core.
pub fn join_type(&self) -> JoinType {
    self.core.join_type
}
168
/// Returns the `(left, right)` column index pairs reported by the core's `eq_indexes`.
///
/// NOTE(review): presumably these are the equi-join key pairs of the `on` condition —
/// confirm against `generic::Join::eq_indexes`.
pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
    self.core.eq_indexes()
}
173
/// Borrows the output indices stored in the core.
pub fn output_indices(&self) -> &Vec<usize> {
    &self.core.output_indices
}
178
179 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
181 Self::with_core(generic::Join {
182 output_indices,
183 ..self.core.clone()
184 })
185 }
186
187 pub fn clone_with_cond(&self, on: Condition) -> Self {
189 Self::with_core(generic::Join {
190 on: generic::JoinOn::Condition(on),
191 ..self.core.clone()
192 })
193 }
194
195 pub fn is_left_join(&self) -> bool {
196 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
197 }
198
199 pub fn is_right_join(&self) -> bool {
200 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
201 }
202
/// Delegates to the core's `is_full_out`.
///
/// NOTE(review): presumably true when the join outputs columns of both inputs (i.e. it is
/// not a semi/anti join) — confirm against `generic::Join::is_full_out`.
pub fn is_full_out(&self) -> bool {
    self.core.is_full_out()
}
206
207 pub fn is_asof_join(&self) -> bool {
208 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
209 }
210
211 pub fn output_indices_are_trivial(&self) -> bool {
212 itertools::equal(
213 self.output_indices().iter().cloned(),
214 0..self.internal_column_num(),
215 )
216 }
217
218 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
223 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
224 JoinType::LeftOuter => (false, true),
225 JoinType::RightOuter => (true, false),
226 JoinType::FullOuter => (true, true),
227 _ => return join_type,
228 };
229
230 for expr in &predicate.conjunctions {
231 if let ExprImpl::FunctionCall(func) = expr {
232 match func.func_type() {
233 ExprType::Equal
234 | ExprType::NotEqual
235 | ExprType::LessThan
236 | ExprType::LessThanOrEqual
237 | ExprType::GreaterThan
238 | ExprType::GreaterThanOrEqual => {
239 for input in func.inputs() {
240 if let ExprImpl::InputRef(input) = input {
241 let idx = input.index;
242 if idx < left_col_num {
243 gen_null_in_left = false;
244 } else {
245 gen_null_in_right = false;
246 }
247 }
248 }
249 }
250 _ => {}
251 };
252 }
253 }
254
255 match (gen_null_in_left, gen_null_in_right) {
256 (true, true) => JoinType::FullOuter,
257 (true, false) => JoinType::RightOuter,
258 (false, true) => JoinType::LeftOuter,
259 (false, false) => JoinType::Inner,
260 }
261 }
262
263 fn to_batch_lookup_join_with_index_selection(
267 &self,
268 predicate: EqJoinPredicate,
269 batch_join: generic::Join<BatchPlanRef>,
270 ) -> Result<Option<BatchLookupJoin>> {
271 match batch_join.join_type {
272 JoinType::Inner
273 | JoinType::LeftOuter
274 | JoinType::LeftSemi
275 | JoinType::LeftAnti
276 | JoinType::AsofInner
277 | JoinType::AsofLeftOuter => {}
278 _ => return Ok(None),
279 };
280
281 let right = self.right();
283 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
285 logical_scan
286 } else {
287 return Ok(None);
288 };
289
290 let mut result_plan = None;
291 if let Some(lookup_join) =
293 self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
294 {
295 result_plan = Some(lookup_join);
296 }
297
298 if self
299 .core
300 .ctx()
301 .session_ctx()
302 .config()
303 .enable_index_selection()
304 {
305 let indexes = logical_scan.table_indexes();
306 for index in indexes {
307 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
308 let index_scan: PlanRef = index_scan.into();
309 let that = self.clone_with_left_right(self.left(), index_scan.clone());
310 let mut new_batch_join = batch_join.clone();
311 new_batch_join.right =
312 index_scan.to_batch().expect("index scan failed to batch");
313
314 if let Some(lookup_join) =
316 that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
317 {
318 match &result_plan {
319 None => result_plan = Some(lookup_join),
320 Some(prev_lookup_join) => {
321 if prev_lookup_join.lookup_prefix_len()
323 < lookup_join.lookup_prefix_len()
324 {
325 result_plan = Some(lookup_join)
326 }
327 }
328 }
329 }
330 }
331 }
332 }
333
334 Ok(result_plan)
335 }
336
337 fn to_batch_lookup_join(
339 &self,
340 predicate: EqJoinPredicate,
341 logical_join: generic::Join<BatchPlanRef>,
342 ) -> Result<Option<BatchLookupJoin>> {
343 let logical_scan: &LogicalScan =
344 if let Some(logical_scan) = self.core.right.as_logical_scan() {
345 logical_scan
346 } else {
347 return Ok(None);
348 };
349 Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
350 }
351
352 pub fn gen_batch_lookup_join(
353 logical_scan: &LogicalScan,
354 predicate: EqJoinPredicate,
355 logical_join: generic::Join<BatchPlanRef>,
356 is_as_of: bool,
357 ) -> Result<Option<BatchLookupJoin>> {
358 match logical_join.join_type {
359 JoinType::Inner
360 | JoinType::LeftOuter
361 | JoinType::LeftSemi
362 | JoinType::LeftAnti
363 | JoinType::AsofInner
364 | JoinType::AsofLeftOuter => {}
365 _ => return Ok(None),
366 };
367
368 let table = logical_scan.table();
369 let output_column_ids = logical_scan.output_column_ids();
370
371 let order_col_ids = table.order_column_ids();
374 let dist_key = table.distribution_key.clone();
375 let mut dist_key_in_order_key_pos = vec![];
377 for d in dist_key {
378 let pos = table
379 .order_column_indices()
380 .position(|x| x == d)
381 .expect("dist_key must in order_key");
382 dist_key_in_order_key_pos.push(pos);
383 }
384 let shortest_prefix_len = dist_key_in_order_key_pos
386 .iter()
387 .max()
388 .map_or(0, |pos| pos + 1);
389
390 if shortest_prefix_len == 0 {
392 return Ok(None);
393 }
394
395 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
397 for order_col_id in order_col_ids {
398 let mut found = false;
399 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
400 if order_col_id == output_column_ids[eq_idx] {
401 reorder_idx.push(i);
402 found = true;
403 break;
404 }
405 }
406 if !found {
407 break;
408 }
409 }
410 if reorder_idx.len() < shortest_prefix_len {
411 return Ok(None);
412 }
413 let lookup_prefix_len = reorder_idx.len();
414 let predicate = predicate.reorder(&reorder_idx);
415
416 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
418 let o2r = if let Some(project_expr) = project_expr {
420 project_expr
421 .into_iter()
422 .map(|x| x.as_input_ref().unwrap().index)
423 .collect_vec()
424 } else {
425 (0..logical_scan.output_col_idx().len()).collect_vec()
426 };
427 let left_schema_len = logical_join.left.schema().len();
428
429 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
430 offset: left_schema_len,
431 mapping: o2r.clone(),
432 };
433
434 let new_eq_cond = predicate
435 .eq_cond()
436 .rewrite_expr(&mut join_predicate_rewriter);
437
438 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
439 offset: left_schema_len,
440 };
441
442 let new_other_cond = predicate
443 .other_cond()
444 .clone()
445 .rewrite_expr(&mut join_predicate_rewriter)
446 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
447
448 let new_join_on = new_eq_cond.and(new_other_cond);
449 let new_predicate =
450 EqJoinPredicate::create(left_schema_len, new_scan.schema().len(), new_join_on);
451
452 if !new_predicate.has_eq() {
455 return Ok(None);
456 }
457
458 let new_join_output_indices = logical_join
461 .output_indices
462 .iter()
463 .map(|&x| {
464 if x < left_schema_len {
465 x
466 } else {
467 o2r[x - left_schema_len] + left_schema_len
468 }
469 })
470 .collect_vec();
471
472 let new_scan_output_column_ids = new_scan.output_column_ids();
473 let as_of = new_scan.as_of.clone();
474 let new_logical_scan: LogicalScan = new_scan.into();
475
476 let new_logical_join = generic::Join::new_with_eq_predicate(
478 logical_join.left,
479 new_logical_scan.to_batch()?,
480 new_predicate,
481 logical_join.join_type,
482 new_join_output_indices,
483 );
484
485 let asof_desc = is_as_of
486 .then(|| {
487 Self::get_inequality_desc_from_predicate(
488 predicate.other_cond().clone(),
489 left_schema_len,
490 )
491 })
492 .transpose()?;
493
494 Ok(Some(BatchLookupJoin::new(
495 new_logical_join,
496 table.clone(),
497 new_scan_output_column_ids,
498 lookup_prefix_len,
499 false,
500 as_of,
501 asof_desc,
502 )))
503 }
504
/// Consumes the join and returns the parts produced by the core's `decompose`.
///
/// By the signature these are two inputs, the condition, the join type and the output
/// indices. NOTE(review): the two `PlanRef`s are presumably `(left, right)` in that
/// order — confirm against `generic::Join::decompose`.
pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
    self.core.decompose()
}
508}
509
510impl PlanTreeNodeBinary<Logical> for LogicalJoin {
511 fn left(&self) -> PlanRef {
512 self.core.left.clone()
513 }
514
515 fn right(&self) -> PlanRef {
516 self.core.right.clone()
517 }
518
519 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
520 Self::with_core(generic::Join {
521 left,
522 right,
523 ..self.core.clone()
524 })
525 }
526
527 fn rewrite_with_left_right(
528 &self,
529 left: PlanRef,
530 left_col_change: ColIndexMapping,
531 right: PlanRef,
532 right_col_change: ColIndexMapping,
533 ) -> (Self, ColIndexMapping) {
534 let (new_on, new_output_indices) = {
535 let (mut map, _) = left_col_change.clone().into_parts();
536 let (mut right_map, _) = right_col_change.clone().into_parts();
537 for i in right_map.iter_mut().flatten() {
538 *i += left.schema().len();
539 }
540 map.append(&mut right_map);
541 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
542
543 let new_output_indices = self
544 .output_indices()
545 .iter()
546 .map(|&i| mapping.map(i))
547 .collect::<Vec<_>>();
548 let new_on = self.on().clone().rewrite_expr(&mut mapping);
549 (new_on, new_output_indices)
550 };
551
552 let join = Self::with_output_indices(
553 left,
554 right,
555 self.join_type(),
556 new_on,
557 new_output_indices.clone(),
558 );
559
560 let new_i2o = ColIndexMapping::with_remaining_columns(
561 &new_output_indices,
562 join.internal_column_num(),
563 );
564
565 let old_o2i = self.core.o2i_col_mapping();
566
567 let old_o2l = old_o2i
568 .composite(&self.core.i2l_col_mapping())
569 .composite(&left_col_change);
570 let old_o2r = old_o2i
571 .composite(&self.core.i2r_col_mapping())
572 .composite(&right_col_change);
573 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
574 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
575
576 let out_col_change = old_o2l
577 .composite(&new_l2o)
578 .union(&old_o2r.composite(&new_r2o));
579 (join, out_col_change)
580 }
581}
582
// NOTE(review): presumably derives the generic plan-tree-node impl for `LogicalJoin` from
// its `PlanTreeNodeBinary` impl — confirm against the macro definition.
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
584
585impl ColPrunable for LogicalJoin {
586 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
587 let required_cols = required_cols
589 .iter()
590 .map(|i| self.output_indices()[*i])
591 .collect_vec();
592 let left_len = self.left().schema().fields.len();
593
594 let total_len = self.left().schema().len() + self.right().schema().len();
595 let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
596
597 required_cols.iter().for_each(|&i| {
598 if self.is_right_join() {
599 resized_required_cols.insert(left_len + i);
600 } else {
601 resized_required_cols.insert(i);
602 }
603 });
604
605 let mut visitor = CollectInputRef::new(resized_required_cols);
608 self.on().visit_expr(&mut visitor);
609 let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
610
611 let mut left_required_cols = Vec::new();
612 let mut right_required_cols = Vec::new();
613 left_right_required_cols.iter().for_each(|&i| {
614 if i < left_len {
615 left_required_cols.push(i);
616 } else {
617 right_required_cols.push(i - left_len);
618 }
619 });
620
621 let mut on = self.on().clone();
622 let mut mapping =
623 ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
624 on = on.rewrite_expr(&mut mapping);
625
626 let new_output_indices = {
627 let required_inputs_in_output = if self.is_left_join() {
628 &left_required_cols
629 } else if self.is_right_join() {
630 &right_required_cols
631 } else {
632 &left_right_required_cols
633 };
634
635 let mapping =
636 ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
637 required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
638 };
639
640 LogicalJoin::with_output_indices(
641 self.left().prune_col(&left_required_cols, ctx),
642 self.right().prune_col(&right_required_cols, ctx),
643 self.join_type(),
644 on,
645 new_output_indices,
646 )
647 .into()
648 }
649}
650
651impl ExprRewritable<Logical> for LogicalJoin {
652 fn has_rewritable_expr(&self) -> bool {
653 true
654 }
655
656 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
657 let mut core = self.core.clone();
658 core.rewrite_exprs(r);
659 Self {
660 base: self.base.clone_with_new_plan_id(),
661 core,
662 }
663 .into()
664 }
665}
666
impl ExprVisitable for LogicalJoin {
    /// Forwards the visitor to the join core, which owns this node's expressions.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
672
673fn derive_predicate_from_eq_condition(
691 expr: &ExprImpl,
692 eq_condition: &EqJoinPredicate,
693 col_num: usize,
694 expr_is_left: bool,
695) -> Option<ExprImpl> {
696 if expr.is_impure() {
697 return None;
698 }
699 let eq_indices = eq_condition
700 .eq_indexes_typed()
701 .iter()
702 .filter_map(|(l, r)| {
703 if l.return_type() != r.return_type() {
704 None
705 } else if expr_is_left {
706 Some(l.index())
707 } else {
708 Some(r.index())
709 }
710 })
711 .collect_vec();
712 if expr
713 .collect_input_refs(col_num)
714 .ones()
715 .any(|index| !eq_indices.contains(&index))
716 {
717 return None;
719 }
720 let other_side_mapping = if expr_is_left {
723 eq_condition.eq_indexes_typed().into_iter().collect()
724 } else {
725 eq_condition
726 .eq_indexes_typed()
727 .into_iter()
728 .map(|(x, y)| (y, x))
729 .collect()
730 };
731 struct InputRefsRewriter {
732 mapping: HashMap<InputRef, InputRef>,
733 }
734 impl ExprRewriter for InputRefsRewriter {
735 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
736 self.mapping[&input_ref].clone().into()
737 }
738 }
739 Some(
740 InputRefsRewriter {
741 mapping: other_side_mapping,
742 }
743 .rewrite_expr(expr.clone()),
744 )
745}
746
747struct LookupJoinPredicateRewriter {
749 offset: usize,
750 mapping: Vec<usize>,
751}
752impl ExprRewriter for LookupJoinPredicateRewriter {
753 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
754 if input_ref.index() < self.offset {
755 input_ref.into()
756 } else {
757 InputRef::new(
758 self.mapping[input_ref.index() - self.offset] + self.offset,
759 input_ref.return_type(),
760 )
761 .into()
762 }
763 }
764}
765
766struct LookupJoinScanPredicateRewriter {
768 offset: usize,
769}
770impl ExprRewriter for LookupJoinScanPredicateRewriter {
771 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
772 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
773 }
774}
775
776impl PredicatePushdown for LogicalJoin {
777 fn predicate_pushdown(
801 &self,
802 predicate: Condition,
803 ctx: &mut PredicatePushdownContext,
804 ) -> PlanRef {
805 let mut predicate = {
807 let mut mapping = self.core.o2i_col_mapping();
808 predicate.rewrite_expr(&mut mapping)
809 };
810
811 let left_col_num = self.left().schema().len();
812 let right_col_num = self.right().schema().len();
813 let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
814
815 let push_down_temporal_predicate = self.temporal_join_on().is_none();
816
817 let (left_from_filter, right_from_filter, on) = push_down_into_join(
818 &mut predicate,
819 left_col_num,
820 right_col_num,
821 join_type,
822 push_down_temporal_predicate,
823 );
824
825 let mut new_on = self.on().clone().and(on);
826 let (left_from_on, right_from_on) = push_down_join_condition(
827 &mut new_on,
828 left_col_num,
829 right_col_num,
830 join_type,
831 push_down_temporal_predicate,
832 );
833
834 let left_predicate = left_from_filter.and(left_from_on);
835 let right_predicate = right_from_filter.and(right_from_on);
836
837 let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
839
840 let right_from_left = if matches!(
842 join_type,
843 JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
844 ) {
845 Condition {
846 conjunctions: left_predicate
847 .conjunctions
848 .iter()
849 .filter_map(|expr| {
850 derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
851 })
852 .collect(),
853 }
854 } else {
855 Condition::true_cond()
856 };
857
858 let left_from_right = if matches!(
860 join_type,
861 JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
862 ) {
863 Condition {
864 conjunctions: right_predicate
865 .conjunctions
866 .iter()
867 .filter_map(|expr| {
868 derive_predicate_from_eq_condition(
869 expr,
870 &eq_condition,
871 right_col_num,
872 false,
873 )
874 })
875 .collect(),
876 }
877 } else {
878 Condition::true_cond()
879 };
880
881 let left_predicate = left_predicate.and(left_from_right);
882 let right_predicate = right_predicate.and(right_from_left);
883
884 let new_left = self.left().predicate_pushdown(left_predicate, ctx);
885 let new_right = self.right().predicate_pushdown(right_predicate, ctx);
886 let new_join = LogicalJoin::with_output_indices(
887 new_left,
888 new_right,
889 join_type,
890 new_on,
891 self.output_indices().clone(),
892 );
893
894 let mut mapping = self.core.i2o_col_mapping();
895 predicate = predicate.rewrite_expr(&mut mapping);
896 LogicalFilter::create(new_join.into(), predicate)
897 }
898}
899
/// Borrowed `LogicalScan` that has been recognised as the right side of a temporal join.
///
/// Within this file it is only constructed by `temporal_join_on`, after checking that the
/// scan's `as_of` is `AsOf::ProcessTime`.
#[derive(Clone, Copy)]
struct TemporalJoinScan<'a>(&'a LogicalScan);

impl<'a> Deref for TemporalJoinScan<'a> {
    type Target = LogicalScan;

    /// Gives transparent access to the wrapped scan.
    fn deref(&self) -> &Self::Target {
        self.0
    }
}
910
911impl LogicalJoin {
912 fn get_stream_input_for_hash_join(
913 &self,
914 predicate: &EqJoinPredicate,
915 ctx: &mut ToStreamContext,
916 ) -> Result<(StreamPlanRef, StreamPlanRef)> {
917 use super::stream::prelude::*;
918
919 let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
920 let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();
921
922 let logical_right = try_enforce_locality_requirement(self.right(), &rhs_join_key_idx);
923 let mut right = logical_right.to_stream_with_dist_required(
924 &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
925 ctx,
926 )?;
927 let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
928 let r2l =
929 predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
930 let l2r =
931 predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
932 let mut left;
933 let right_dist = right.distribution();
934 match right_dist {
935 Distribution::HashShard(_) => {
936 let left_dist = r2l
937 .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
938 left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
939 }
940 Distribution::UpstreamHashShard(_, _) => {
941 left = logical_left.to_stream_with_dist_required(
942 &RequiredDist::shard_by_key(
943 self.left().schema().len(),
944 &predicate.left_eq_indexes(),
945 ),
946 ctx,
947 )?;
948 let left_dist = left.distribution();
949 match left_dist {
950 Distribution::HashShard(_) => {
951 let right_dist = l2r.rewrite_required_distribution(
952 &RequiredDist::PhysicalDist(left_dist.clone()),
953 );
954 right = right_dist.streaming_enforce_if_not_satisfies(right)?
955 }
956 Distribution::UpstreamHashShard(_, _) => {
957 left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
958 .streaming_enforce_if_not_satisfies(left)?;
959 right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
960 .streaming_enforce_if_not_satisfies(right)?;
961 }
962 _ => unreachable!(),
963 }
964 }
965 _ => unreachable!(),
966 }
967 Ok((left, right))
968 }
969
970 fn to_stream_hash_join(
971 &self,
972 predicate: EqJoinPredicate,
973 ctx: &mut ToStreamContext,
974 ) -> Result<StreamPlanRef> {
975 use super::stream::prelude::*;
976
977 assert!(predicate.has_eq());
978 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
979
980 let mut core = self.core.clone_with_inputs(left, right);
981 core.on = generic::JoinOn::EqPredicate(predicate);
982
983 let stream_hash_join = StreamHashJoin::new(core.clone())?;
992 let predicate = stream_hash_join.eq_join_predicate().clone();
993
994 let force_filter_inside_join = self
995 .base
996 .ctx()
997 .session_ctx()
998 .config()
999 .streaming_force_filter_inside_join();
1000
1001 let pull_filter = self.join_type() == JoinType::Inner
1002 && stream_hash_join.eq_join_predicate().has_non_eq()
1003 && stream_hash_join.inequality_pairs().is_empty()
1004 && (!force_filter_inside_join);
1005 if pull_filter {
1006 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1007
1008 let mut core = core;
1009 core.output_indices = default_indices.clone();
1010 let eq_cond = EqJoinPredicate::new(
1012 Condition::true_cond(),
1013 predicate.eq_keys().to_vec(),
1014 self.left().schema().len(),
1015 self.right().schema().len(),
1016 );
1017 core.on = generic::JoinOn::EqPredicate(eq_cond);
1018 let hash_join = StreamHashJoin::new(core)?.into();
1019 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1020 let plan = StreamFilter::new(logical_filter).into();
1021 if self.output_indices() != &default_indices {
1022 let logical_project = generic::Project::with_mapping(
1023 plan,
1024 ColIndexMapping::with_remaining_columns(
1025 self.output_indices(),
1026 self.internal_column_num(),
1027 ),
1028 );
1029 Ok(StreamProject::new(logical_project).into())
1030 } else {
1031 Ok(plan)
1032 }
1033 } else {
1034 Ok(stream_hash_join.into())
1035 }
1036 }
1037
/// Returns `true` when the right input is a scan with `AsOf::ProcessTime`, i.e. when
/// `temporal_join_on` finds a temporal-join scan.
pub fn should_be_temporal_join(&self) -> bool {
    self.temporal_join_on().is_some()
}
1041
1042 fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1043 if let Some(logical_scan) = self.core.right.as_logical_scan() {
1044 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1045 .then_some(TemporalJoinScan(logical_scan))
1046 } else {
1047 None
1048 }
1049 }
1050
1051 fn should_be_stream_temporal_join<'a>(
1052 &'a self,
1053 ctx: &ToStreamContext,
1054 ) -> Result<Option<TemporalJoinScan<'a>>> {
1055 Ok(if let Some(scan) = self.temporal_join_on() {
1056 if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1057 return Err(RwError::from(ErrorCode::NotSupported(
1058 "Temporal join with snapshot backfill not supported".into(),
1059 "Please use arrangement backfill".into(),
1060 )));
1061 }
1062 if scan.cross_database() {
1063 return Err(RwError::from(ErrorCode::NotSupported(
1064 "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1065 "Please ensure both tables are in the same database".into(),
1066 )));
1067 }
1068 Some(scan)
1069 } else {
1070 None
1071 })
1072 }
1073
1074 fn to_stream_temporal_join_with_index_selection(
1075 &self,
1076 logical_scan: TemporalJoinScan<'_>,
1077 predicate: EqJoinPredicate,
1078 ctx: &mut ToStreamContext,
1079 ) -> Result<StreamPlanRef> {
1080 let mut result_plan: Result<StreamTemporalJoin> =
1082 self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1083 if let Ok(temporal_join) = &result_plan
1085 && temporal_join.eq_join_predicate().eq_indexes().len()
1086 == logical_scan.primary_key().len()
1087 {
1088 return result_plan.map(|x| x.into());
1089 }
1090 if self
1091 .core
1092 .ctx()
1093 .session_ctx()
1094 .config()
1095 .enable_index_selection()
1096 {
1097 let indexes = logical_scan.table_indexes();
1098 for index in indexes {
1099 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1101 let index_scan: PlanRef = index_scan.into();
1102 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1103 if let Ok(temporal_join) = that.to_stream_temporal_join(
1104 that.temporal_join_on().expect(
1105 "index scan created from temporal join scan must also be temporal join",
1106 ),
1107 predicate.clone(),
1108 ctx,
1109 ) {
1110 match &result_plan {
1111 Err(_) => result_plan = Ok(temporal_join),
1112 Ok(prev_temporal_join) => {
1113 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1115 < temporal_join.eq_join_predicate().eq_indexes().len()
1116 {
1117 result_plan = Ok(temporal_join)
1118 }
1119 }
1120 }
1121 }
1122 }
1123 }
1124 }
1125
1126 result_plan.map(|x| x.into())
1127 }
1128
1129 fn temporal_join_scan_predicate_pull_up(
1130 logical_scan: TemporalJoinScan<'_>,
1131 predicate: EqJoinPredicate,
1132 output_indices: &[usize],
1133 left_schema_len: usize,
1134 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1135 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1137 let o2r = if let Some(project_expr) = project_expr {
1139 project_expr
1140 .into_iter()
1141 .map(|x| x.as_input_ref().unwrap().index)
1142 .collect_vec()
1143 } else {
1144 (0..logical_scan.output_col_idx().len()).collect_vec()
1145 };
1146 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1147 offset: left_schema_len,
1148 mapping: o2r.clone(),
1149 };
1150
1151 let new_eq_cond = predicate
1152 .eq_cond()
1153 .rewrite_expr(&mut join_predicate_rewriter);
1154
1155 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1156 offset: left_schema_len,
1157 };
1158
1159 let new_other_cond = predicate
1160 .other_cond()
1161 .clone()
1162 .rewrite_expr(&mut join_predicate_rewriter)
1163 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1164
1165 let new_join_on = new_eq_cond.and(new_other_cond);
1166
1167 let new_predicate = EqJoinPredicate::create(
1168 left_schema_len,
1169 new_scan.schema().len(),
1170 new_join_on.clone(),
1171 );
1172
1173 let new_join_output_indices = output_indices
1176 .iter()
1177 .map(|&x| {
1178 if x < left_schema_len {
1179 x
1180 } else {
1181 o2r[x - left_schema_len] + left_schema_len
1182 }
1183 })
1184 .collect_vec();
1185
1186 let new_stream_table_scan =
1187 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1188 Ok((
1189 new_stream_table_scan,
1190 new_predicate,
1191 new_join_on,
1192 new_join_output_indices,
1193 ))
1194 }
1195
1196 fn to_stream_temporal_join(
1197 &self,
1198 logical_scan: TemporalJoinScan<'_>,
1199 predicate: EqJoinPredicate,
1200 ctx: &mut ToStreamContext,
1201 ) -> Result<StreamTemporalJoin> {
1202 use super::stream::prelude::*;
1203
1204 assert!(predicate.has_eq());
1205
1206 let table = logical_scan.table();
1207 let output_column_ids = logical_scan.output_column_ids();
1208
1209 let order_col_ids = table.order_column_ids();
1212 let dist_key = table.distribution_key.clone();
1213
1214 let mut dist_key_in_order_key_pos = vec![];
1215 for d in dist_key {
1216 let pos = table
1217 .order_column_indices()
1218 .position(|x| x == d)
1219 .expect("dist_key must in order_key");
1220 dist_key_in_order_key_pos.push(pos);
1221 }
1222 let shortest_prefix_len = dist_key_in_order_key_pos
1224 .iter()
1225 .max()
1226 .map_or(0, |pos| pos + 1);
1227
1228 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1230 for order_col_id in order_col_ids {
1231 let mut found = false;
1232 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1233 if order_col_id == output_column_ids[eq_idx] {
1234 reorder_idx.push(i);
1235 found = true;
1236 break;
1237 }
1238 }
1239 if !found {
1240 break;
1241 }
1242 }
1243 if reorder_idx.len() < shortest_prefix_len {
1244 return Err(RwError::from(ErrorCode::NotSupported(
1245 "Temporal join requires the equivalence join condition includes the key columns that form the distribution key of the lookup table".into(),
1246 concat!(
1247 "Use DESCRIBE <table_name> to view the table's key information.\n",
1248 "You can create an index on the lookup table to facilitate the temporal join if necessary."
1249 ).into(),
1250 )));
1251 }
1252 let lookup_prefix_len = reorder_idx.len();
1253 let predicate = predicate.reorder(&reorder_idx);
1254
1255 let required_dist = if dist_key_in_order_key_pos.is_empty() {
1256 RequiredDist::single()
1257 } else {
1258 let left_eq_indexes = predicate.left_eq_indexes();
1259 let left_dist_key = dist_key_in_order_key_pos
1260 .iter()
1261 .map(|pos| left_eq_indexes[*pos])
1262 .collect_vec();
1263
1264 RequiredDist::hash_shard(&left_dist_key)
1265 };
1266
1267 let lhs_join_key_idx = predicate
1268 .eq_indexes()
1269 .into_iter()
1270 .map(|(l, _)| l)
1271 .collect_vec();
1272 let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
1273 let left = logical_left.to_stream(ctx)?;
1274 let left = required_dist.stream_enforce(left);
1276
1277 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1278 Self::temporal_join_scan_predicate_pull_up(
1279 logical_scan,
1280 predicate,
1281 self.output_indices(),
1282 self.left().schema().len(),
1283 )?;
1284
1285 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1286 if !new_predicate.has_eq() {
1287 return Err(RwError::from(ErrorCode::NotSupported(
1288 "Temporal join requires a non trivial join condition".into(),
1289 "Please remove the false condition of the join".into(),
1290 )));
1291 }
1292
1293 let new_logical_join = generic::Join::new(
1295 left,
1296 right,
1297 new_join_on,
1298 self.join_type(),
1299 new_join_output_indices,
1300 );
1301
1302 let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1303
1304 let mut new_logical_join = new_logical_join;
1305 new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1306 StreamTemporalJoin::new(new_logical_join, false)
1307 }
1308
1309 fn to_stream_nested_loop_temporal_join(
1310 &self,
1311 logical_scan: TemporalJoinScan<'_>,
1312 predicate: EqJoinPredicate,
1313 ctx: &mut ToStreamContext,
1314 ) -> Result<StreamPlanRef> {
1315 use super::stream::prelude::*;
1316 assert!(!predicate.has_eq());
1317
1318 let left = self.left().to_stream_with_dist_required(
1319 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1320 ctx,
1321 )?;
1322 assert!(left.as_stream_exchange().is_some());
1323
1324 if self.join_type() != JoinType::Inner {
1325 return Err(RwError::from(ErrorCode::NotSupported(
1326 "Temporal join requires an inner join".into(),
1327 "Please use an inner join".into(),
1328 )));
1329 }
1330
1331 if !left.append_only() {
1332 return Err(RwError::from(ErrorCode::NotSupported(
1333 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1334 "Please ensure the left hash side is append only".into(),
1335 )));
1336 }
1337
1338 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1339 Self::temporal_join_scan_predicate_pull_up(
1340 logical_scan,
1341 predicate,
1342 self.output_indices(),
1343 self.left().schema().len(),
1344 )?;
1345
1346 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1347
1348 let new_logical_join = generic::Join::new(
1350 left,
1351 right,
1352 new_join_on,
1353 self.join_type(),
1354 new_join_output_indices,
1355 );
1356
1357 let mut new_logical_join = new_logical_join;
1358 new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1359 Ok(StreamTemporalJoin::new(new_logical_join, true)?.into())
1360 }
1361
1362 fn to_stream_dynamic_filter(
1363 &self,
1364 predicate: Condition,
1365 ctx: &mut ToStreamContext,
1366 ) -> Result<Option<StreamPlanRef>> {
1367 use super::stream::prelude::*;
1368
1369 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1375 return Ok(None);
1376 }
1377
1378 if !self.right().max_one_row() {
1380 return Ok(None);
1381 }
1382 if self.right().schema().len() != 1 {
1383 return Ok(None);
1384 }
1385
1386 if predicate.conjunctions.len() > 1 {
1388 return Ok(None);
1389 }
1390 let expr: ExprImpl = predicate.into();
1391 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1392 Some(v) => v,
1393 None => return Ok(None),
1394 };
1395
1396 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1397 && right_ref.index == self.left().schema().len() ;
1398 if !condition_cross_inputs {
1399 return Ok(None);
1401 }
1402
1403 if self.left().schema().fields()[left_ref.index].data_type
1405 != self.right().schema().fields()[0].data_type
1406 {
1407 return Ok(None);
1408 }
1409
1410 let all_output_from_left = self
1412 .output_indices()
1413 .iter()
1414 .all(|i| *i < self.left().schema().len());
1415 if !all_output_from_left {
1416 return Ok(None);
1417 }
1418
1419 let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1420 let right = self.right().to_stream_with_dist_required(
1421 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1422 ctx,
1423 )?;
1424
1425 assert!(right.as_stream_exchange().is_some());
1426 assert_eq!(
1427 *right.inputs().iter().exactly_one().unwrap().distribution(),
1428 Distribution::Single
1429 );
1430
1431 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1432 let plan = StreamDynamicFilter::new(core)?.into();
1433 if self
1435 .output_indices()
1436 .iter()
1437 .copied()
1438 .ne(0..self.left().schema().len())
1439 {
1440 let logical_project = generic::Project::with_mapping(
1443 plan,
1444 ColIndexMapping::with_remaining_columns(
1445 self.output_indices(),
1446 self.left().schema().len(),
1447 ),
1448 );
1449 Ok(Some(StreamProject::new(logical_project).into()))
1450 } else {
1451 Ok(Some(plan))
1452 }
1453 }
1454
1455 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1456 let predicate = EqJoinPredicate::create(
1457 self.left().schema().len(),
1458 self.right().schema().len(),
1459 self.on().clone(),
1460 );
1461 assert!(predicate.has_eq());
1462
1463 let join = self
1464 .core
1465 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1466
1467 Ok(self
1468 .to_batch_lookup_join(predicate, join)?
1469 .expect("Fail to convert to lookup join")
1470 .into())
1471 }
1472
1473 fn to_stream_asof_join(
1474 &self,
1475 predicate: EqJoinPredicate,
1476 ctx: &mut ToStreamContext,
1477 ) -> Result<StreamPlanRef> {
1478 use super::stream::prelude::*;
1479
1480 if predicate.eq_keys().is_empty() {
1481 return Err(ErrorCode::InvalidInputSyntax(
1482 "AsOf join requires at least 1 equal condition".to_owned(),
1483 )
1484 .into());
1485 }
1486
1487 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1488 let left_len = left.schema().len();
1489 let mut core = self.core.clone_with_inputs(left, right);
1490 core.on = generic::JoinOn::EqPredicate(predicate);
1491
1492 let inequality_desc = Self::get_inequality_desc_from_predicate(
1493 core.on
1494 .as_eq_predicate_ref()
1495 .expect("core predicate must exist")
1496 .other_cond()
1497 .clone(),
1498 left_len,
1499 )?;
1500
1501 Ok(StreamAsOfJoin::new(core, inequality_desc)?.into())
1502 }
1503
1504 fn to_batch_hash_join(
1506 &self,
1507 logical_join: generic::Join<BatchPlanRef>,
1508 predicate: EqJoinPredicate,
1509 ) -> Result<BatchPlanRef> {
1510 use super::batch::prelude::*;
1511
1512 let left_schema_len = logical_join.left.schema().len();
1513 let asof_desc = self
1514 .is_asof_join()
1515 .then(|| {
1516 Self::get_inequality_desc_from_predicate(
1517 predicate.other_cond().clone(),
1518 left_schema_len,
1519 )
1520 })
1521 .transpose()?;
1522
1523 let logical_join = generic::Join {
1524 on: generic::JoinOn::EqPredicate(predicate),
1525 ..logical_join
1526 };
1527 let batch_join = BatchHashJoin::new(logical_join, asof_desc);
1528 Ok(batch_join.into())
1529 }
1530
1531 pub fn get_inequality_desc_from_predicate(
1532 predicate: Condition,
1533 left_input_len: usize,
1534 ) -> Result<AsOfJoinDesc> {
1535 let expr: ExprImpl = predicate.into();
1536 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1537 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1538 {
1539 Ok(AsOfJoinDesc {
1540 left_idx: left_input_ref.index() as u32,
1541 right_idx: (right_input_ref.index() - left_input_len) as u32,
1542 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1543 })
1544 } else {
1545 bail!("inequal condition from the same side should be push down in optimizer");
1546 }
1547 } else {
1548 Err(ErrorCode::InvalidInputSyntax(
1549 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1550 )
1551 .into())
1552 }
1553 }
1554
1555 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1556 match expr_type {
1557 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1558 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1559 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1560 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1561 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1562 "Invalid comparison type: {}",
1563 expr_type.as_str_name()
1564 ))
1565 .into()),
1566 }
1567 }
1568}
1569
1570impl ToBatch for LogicalJoin {
1571 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1572 let predicate = EqJoinPredicate::create(
1573 self.left().schema().len(),
1574 self.right().schema().len(),
1575 self.on().clone(),
1576 );
1577
1578 let batch_join = self
1579 .core
1580 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1581
1582 let ctx = self.base.ctx();
1583 let config = ctx.session_ctx().config();
1584
1585 if predicate.has_eq() {
1586 if !predicate.eq_keys_are_type_aligned() {
1587 return Err(ErrorCode::InternalError(format!(
1588 "Join eq keys are not aligned for predicate: {predicate:?}"
1589 ))
1590 .into());
1591 }
1592 if config.batch_enable_lookup_join()
1593 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1594 predicate.clone(),
1595 batch_join.clone(),
1596 )?
1597 {
1598 return Ok(lookup_join.into());
1599 }
1600 self.to_batch_hash_join(batch_join, predicate)
1601 } else if self.is_asof_join() {
1602 Err(ErrorCode::InvalidInputSyntax(
1603 "AsOf join requires at least 1 equal condition".to_owned(),
1604 )
1605 .into())
1606 } else {
1607 Ok(BatchNestedLoopJoin::new(batch_join).into())
1609 }
1610 }
1611}
1612
1613impl ToStream for LogicalJoin {
1614 fn to_stream(
1615 &self,
1616 ctx: &mut ToStreamContext,
1617 ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1618 if self
1619 .on()
1620 .conjunctions
1621 .iter()
1622 .any(|cond| cond.count_nows() > 0)
1623 {
1624 return Err(ErrorCode::NotSupported(
1625 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1626 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1627 }
1628
1629 let predicate = EqJoinPredicate::create(
1630 self.left().schema().len(),
1631 self.right().schema().len(),
1632 self.on().clone(),
1633 );
1634
1635 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1636 self.to_stream_asof_join(predicate, ctx)
1637 } else if predicate.has_eq() {
1638 if !predicate.eq_keys_are_type_aligned() {
1639 return Err(ErrorCode::InternalError(format!(
1640 "Join eq keys are not aligned for predicate: {predicate:?}"
1641 ))
1642 .into());
1643 }
1644
1645 if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1646 self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1647 } else {
1648 self.to_stream_hash_join(predicate, ctx)
1649 }
1650 } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1651 self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1652 } else if let Some(dynamic_filter) =
1653 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1654 {
1655 Ok(dynamic_filter)
1656 } else {
1657 Err(RwError::from(ErrorCode::NotSupported(
1658 "streaming nested-loop join".to_owned(),
1659 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1660 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1661 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1662 )))
1663 }
1664 }
1665
1666 fn logical_rewrite_for_stream(
1667 &self,
1668 ctx: &mut RewriteStreamContext,
1669 ) -> Result<(PlanRef, ColIndexMapping)> {
1670 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1671 let left_len = left.schema().len();
1672 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1673 let (join, out_col_change) = self.rewrite_with_left_right(
1674 left.clone(),
1675 left_col_change,
1676 right.clone(),
1677 right_col_change,
1678 );
1679
1680 let mapping = ColIndexMapping::with_remaining_columns(
1681 join.output_indices(),
1682 join.internal_column_num(),
1683 );
1684
1685 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1686 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1687
1688 let mut left_to_add = left
1690 .expect_stream_key()
1691 .iter()
1692 .cloned()
1693 .filter(|i| l2o.try_map(*i).is_none())
1694 .collect_vec();
1695
1696 let mut right_to_add = right
1697 .expect_stream_key()
1698 .iter()
1699 .filter(|&&i| r2o.try_map(i).is_none())
1700 .map(|&i| i + left_len)
1701 .collect_vec();
1702
1703 let right_len = right.schema().len();
1706 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1707
1708 let either_or_both = self.core.add_which_join_key_to_pk();
1709
1710 for (lk, rk) in eq_predicate.eq_indexes() {
1711 match either_or_both {
1712 EitherOrBoth::Left(_) => {
1713 if l2o.try_map(lk).is_none() {
1714 left_to_add.push(lk);
1715 }
1716 }
1717 EitherOrBoth::Right(_) => {
1718 if r2o.try_map(rk).is_none() {
1719 right_to_add.push(rk + left_len)
1720 }
1721 }
1722 EitherOrBoth::Both(_, _) => {
1723 if l2o.try_map(lk).is_none() {
1724 left_to_add.push(lk);
1725 }
1726 if r2o.try_map(rk).is_none() {
1727 right_to_add.push(rk + left_len)
1728 }
1729 }
1730 };
1731 }
1732 let left_to_add = left_to_add.into_iter().unique();
1733 let right_to_add = right_to_add.into_iter().unique();
1734 let mut new_output_indices = join.output_indices().clone();
1737 if !join.is_right_join() {
1738 new_output_indices.extend(left_to_add);
1739 }
1740 if !join.is_left_join() {
1741 new_output_indices.extend(right_to_add);
1742 }
1743
1744 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1745
1746 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1747 let l2o = join_with_pk
1750 .core
1751 .l2i_col_mapping()
1752 .composite(&join_with_pk.core.i2o_col_mapping());
1753 let r2o = join_with_pk
1754 .core
1755 .r2i_col_mapping()
1756 .composite(&join_with_pk.core.i2o_col_mapping());
1757 let left_right_stream_keys = join_with_pk
1758 .left()
1759 .expect_stream_key()
1760 .iter()
1761 .map(|i| l2o.map(*i))
1762 .chain(
1763 join_with_pk
1764 .right()
1765 .expect_stream_key()
1766 .iter()
1767 .map(|i| r2o.map(*i)),
1768 )
1769 .collect_vec();
1770 let plan: PlanRef = join_with_pk.into();
1771 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1772 } else {
1773 join_with_pk.into()
1774 };
1775
1776 Ok((plan, out_col_change))
1778 }
1779
1780 fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1781 let mut ctx = ToStreamContext::new(false);
1782 if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1784 let o2i_mapping = self.core.o2i_col_mapping();
1786 let left_input_columns = columns
1787 .iter()
1788 .map(|&col| o2i_mapping.try_map(col))
1789 .collect::<Option<Vec<usize>>>()?;
1790 if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1791 return Some(
1792 self.clone_with_left_right(better_left_plan, self.right())
1793 .into(),
1794 );
1795 }
1796 }
1797 None
1798 }
1799}
1800
1801#[cfg(test)]
1802mod tests {
1803
1804 use std::collections::HashSet;
1805
1806 use risingwave_common::catalog::{Field, Schema};
1807 use risingwave_common::types::{DataType, Datum};
1808 use risingwave_pb::expr::expr_node::Type;
1809
1810 use super::*;
1811 use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1812 use crate::optimizer::optimizer_context::OptimizerContext;
1813 use crate::optimizer::plan_node::LogicalValues;
1814 use crate::optimizer::property::FunctionalDependency;
1815
1816 #[tokio::test]
1830 async fn test_prune_join() {
1831 let ty = DataType::Int32;
1832 let ctx = OptimizerContext::mock().await;
1833 let fields: Vec<Field> = (1..7)
1834 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1835 .collect();
1836 let left = LogicalValues::new(
1837 vec![],
1838 Schema {
1839 fields: fields[0..3].to_vec(),
1840 },
1841 ctx.clone(),
1842 );
1843 let right = LogicalValues::new(
1844 vec![],
1845 Schema {
1846 fields: fields[3..6].to_vec(),
1847 },
1848 ctx,
1849 );
1850 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1851 FunctionCall::new(
1852 Type::Equal,
1853 vec![
1854 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1855 ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1856 ],
1857 )
1858 .unwrap(),
1859 ));
1860 let join_type = JoinType::Inner;
1861 let join: PlanRef = LogicalJoin::new(
1862 left.into(),
1863 right.into(),
1864 join_type,
1865 Condition::with_expr(on),
1866 )
1867 .into();
1868
1869 let required_cols = vec![2, 3];
1871 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1872
1873 let join = plan.as_logical_join().unwrap();
1875 assert_eq!(join.schema().fields().len(), 2);
1876 assert_eq!(join.schema().fields()[0], fields[2]);
1877 assert_eq!(join.schema().fields()[1], fields[3]);
1878
1879 let expr: ExprImpl = join.on().clone().into();
1880 let call = expr.as_function_call().unwrap();
1881 assert_eq_input_ref!(&call.inputs()[0], 0);
1882 assert_eq_input_ref!(&call.inputs()[1], 2);
1883
1884 let left = join.left();
1885 let left = left.as_logical_values().unwrap();
1886 assert_eq!(left.schema().fields(), &fields[1..3]);
1887 let right = join.right();
1888 let right = right.as_logical_values().unwrap();
1889 assert_eq!(right.schema().fields(), &fields[3..4]);
1890 }
1891
1892 #[tokio::test]
1894 async fn test_prune_semi_join() {
1895 let ty = DataType::Int32;
1896 let ctx = OptimizerContext::mock().await;
1897 let fields: Vec<Field> = (1..7)
1898 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1899 .collect();
1900 let left = LogicalValues::new(
1901 vec![],
1902 Schema {
1903 fields: fields[0..3].to_vec(),
1904 },
1905 ctx.clone(),
1906 );
1907 let right = LogicalValues::new(
1908 vec![],
1909 Schema {
1910 fields: fields[3..6].to_vec(),
1911 },
1912 ctx,
1913 );
1914 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1915 FunctionCall::new(
1916 Type::Equal,
1917 vec![
1918 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1919 ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1920 ],
1921 )
1922 .unwrap(),
1923 ));
1924 for join_type in [
1925 JoinType::LeftSemi,
1926 JoinType::RightSemi,
1927 JoinType::LeftAnti,
1928 JoinType::RightAnti,
1929 ] {
1930 let join = LogicalJoin::new(
1931 left.clone().into(),
1932 right.clone().into(),
1933 join_type,
1934 Condition::with_expr(on.clone()),
1935 );
1936
1937 let offset = if join.is_right_join() { 3 } else { 0 };
1938 let join: PlanRef = join.into();
1939 let required_cols = vec![0];
1941 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1943 let as_plan = plan.as_logical_join().unwrap();
1944 assert_eq!(as_plan.schema().fields().len(), 1);
1946 assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1947
1948 let required_cols = vec![0, 1, 2];
1950 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1952 let as_plan = plan.as_logical_join().unwrap();
1953 assert_eq!(as_plan.schema().fields().len(), 3);
1955 assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1956 assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1957 assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1958 }
1959 }
1960
1961 #[tokio::test]
1974 async fn test_prune_join_no_project() {
1975 let ty = DataType::Int32;
1976 let ctx = OptimizerContext::mock().await;
1977 let fields: Vec<Field> = (1..7)
1978 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1979 .collect();
1980 let left = LogicalValues::new(
1981 vec![],
1982 Schema {
1983 fields: fields[0..3].to_vec(),
1984 },
1985 ctx.clone(),
1986 );
1987 let right = LogicalValues::new(
1988 vec![],
1989 Schema {
1990 fields: fields[3..6].to_vec(),
1991 },
1992 ctx,
1993 );
1994 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1995 FunctionCall::new(
1996 Type::Equal,
1997 vec![
1998 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1999 ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2000 ],
2001 )
2002 .unwrap(),
2003 ));
2004 let join_type = JoinType::Inner;
2005 let join: PlanRef = LogicalJoin::new(
2006 left.into(),
2007 right.into(),
2008 join_type,
2009 Condition::with_expr(on),
2010 )
2011 .into();
2012
2013 let required_cols = vec![1, 3];
2015 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2016
2017 let join = plan.as_logical_join().unwrap();
2019 assert_eq!(join.schema().fields().len(), 2);
2020 assert_eq!(join.schema().fields()[0], fields[1]);
2021 assert_eq!(join.schema().fields()[1], fields[3]);
2022
2023 let expr: ExprImpl = join.on().clone().into();
2024 let call = expr.as_function_call().unwrap();
2025 assert_eq_input_ref!(&call.inputs()[0], 0);
2026 assert_eq_input_ref!(&call.inputs()[1], 1);
2027
2028 let left = join.left();
2029 let left = left.as_logical_values().unwrap();
2030 assert_eq!(left.schema().fields(), &fields[1..2]);
2031 let right = join.right();
2032 let right = right.as_logical_values().unwrap();
2033 assert_eq!(right.schema().fields(), &fields[3..4]);
2034 }
2035
2036 #[tokio::test]
2050 async fn test_join_to_batch() {
2051 let ctx = OptimizerContext::mock().await;
2052 let fields: Vec<Field> = (1..7)
2053 .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
2054 .collect();
2055 let left = LogicalValues::new(
2056 vec![],
2057 Schema {
2058 fields: fields[0..3].to_vec(),
2059 },
2060 ctx.clone(),
2061 );
2062 let right = LogicalValues::new(
2063 vec![],
2064 Schema {
2065 fields: fields[3..6].to_vec(),
2066 },
2067 ctx,
2068 );
2069
2070 fn input_ref(i: usize) -> ExprImpl {
2071 ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
2072 }
2073 let eq_cond = ExprImpl::FunctionCall(Box::new(
2074 FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
2075 ));
2076 let non_eq_cond = ExprImpl::FunctionCall(Box::new(
2077 FunctionCall::new(
2078 Type::Equal,
2079 vec![
2080 input_ref(2),
2081 ExprImpl::Literal(Box::new(Literal::new(
2082 Datum::Some(42_i32.into()),
2083 DataType::Int32,
2084 ))),
2085 ],
2086 )
2087 .unwrap(),
2088 ));
2089 let on_cond = ExprImpl::FunctionCall(Box::new(
2091 FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
2092 ));
2093
2094 let join_type = JoinType::Inner;
2095 let logical_join = LogicalJoin::new(
2096 left.into(),
2097 right.into(),
2098 join_type,
2099 Condition::with_expr(on_cond),
2100 );
2101
2102 let result = logical_join.to_batch().unwrap();
2104
2105 let hash_join = result.as_batch_hash_join().unwrap();
2107 assert_eq!(
2108 ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2109 eq_cond
2110 );
2111 assert_eq!(
2112 *hash_join
2113 .eq_join_predicate()
2114 .non_eq_cond()
2115 .conjunctions
2116 .first()
2117 .unwrap(),
2118 non_eq_cond
2119 );
2120 }
2121
2122 #[tokio::test]
2135 #[ignore] async fn test_join_to_stream() {
2138 }
2206 #[tokio::test]
2220 async fn test_join_column_prune_with_order_required() {
2221 let ty = DataType::Int32;
2222 let ctx = OptimizerContext::mock().await;
2223 let fields: Vec<Field> = (1..7)
2224 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2225 .collect();
2226 let left = LogicalValues::new(
2227 vec![],
2228 Schema {
2229 fields: fields[0..3].to_vec(),
2230 },
2231 ctx.clone(),
2232 );
2233 let right = LogicalValues::new(
2234 vec![],
2235 Schema {
2236 fields: fields[3..6].to_vec(),
2237 },
2238 ctx,
2239 );
2240 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2241 FunctionCall::new(
2242 Type::Equal,
2243 vec![
2244 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2245 ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2246 ],
2247 )
2248 .unwrap(),
2249 ));
2250 let join_type = JoinType::Inner;
2251 let join: PlanRef = LogicalJoin::new(
2252 left.into(),
2253 right.into(),
2254 join_type,
2255 Condition::with_expr(on),
2256 )
2257 .into();
2258
2259 let required_cols = vec![3, 2];
2261 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2262
2263 let join = plan.as_logical_join().unwrap();
2265 assert_eq!(join.schema().fields().len(), 2);
2266 assert_eq!(join.schema().fields()[0], fields[3]);
2267 assert_eq!(join.schema().fields()[1], fields[2]);
2268
2269 let expr: ExprImpl = join.on().clone().into();
2270 let call = expr.as_function_call().unwrap();
2271 assert_eq_input_ref!(&call.inputs()[0], 0);
2272 assert_eq_input_ref!(&call.inputs()[1], 2);
2273
2274 let left = join.left();
2275 let left = left.as_logical_values().unwrap();
2276 assert_eq!(left.schema().fields(), &fields[1..3]);
2277 let right = join.right();
2278 let right = right.as_logical_values().unwrap();
2279 assert_eq!(right.schema().fields(), &fields[3..4]);
2280 }
2281
2282 #[tokio::test]
2283 async fn fd_derivation_inner_outer_join() {
2284 let ctx = OptimizerContext::mock().await;
2307 let left = {
2308 let fields: Vec<Field> = vec![
2309 Field::with_name(DataType::Int32, "l0"),
2310 Field::with_name(DataType::Int32, "l1"),
2311 ];
2312 let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2313 values
2315 .base
2316 .functional_dependency_mut()
2317 .add_functional_dependency_by_column_indices(&[0], &[1]);
2318 values
2319 };
2320 let right = {
2321 let fields: Vec<Field> = vec![
2322 Field::with_name(DataType::Int32, "r0"),
2323 Field::with_name(DataType::Int32, "r1"),
2324 Field::with_name(DataType::Int32, "r2"),
2325 ];
2326 let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2327 values
2329 .base
2330 .functional_dependency_mut()
2331 .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2332 values
2333 };
2334 let on: ExprImpl = FunctionCall::new(
2336 Type::And,
2337 vec![
2338 FunctionCall::new(
2339 Type::Equal,
2340 vec![
2341 InputRef::new(0, DataType::Int32).into(),
2342 ExprImpl::literal_int(0),
2343 ],
2344 )
2345 .unwrap()
2346 .into(),
2347 FunctionCall::new(
2348 Type::Equal,
2349 vec![
2350 InputRef::new(1, DataType::Int32).into(),
2351 InputRef::new(3, DataType::Int32).into(),
2352 ],
2353 )
2354 .unwrap()
2355 .into(),
2356 ],
2357 )
2358 .unwrap()
2359 .into();
2360 let expected_fd_set = [
2361 (
2362 JoinType::Inner,
2363 [
2364 FunctionalDependency::with_indices(5, &[0], &[1]),
2366 FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2368 FunctionalDependency::with_indices(5, &[], &[0]),
2370 FunctionalDependency::with_indices(5, &[1], &[3]),
2372 FunctionalDependency::with_indices(5, &[3], &[1]),
2373 ]
2374 .into_iter()
2375 .collect::<HashSet<_>>(),
2376 ),
2377 (JoinType::FullOuter, HashSet::new()),
2378 (
2379 JoinType::RightOuter,
2380 [
2381 FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2383 ]
2384 .into_iter()
2385 .collect::<HashSet<_>>(),
2386 ),
2387 (
2388 JoinType::LeftOuter,
2389 [
2390 FunctionalDependency::with_indices(5, &[0], &[1]),
2392 ]
2393 .into_iter()
2394 .collect::<HashSet<_>>(),
2395 ),
2396 (
2397 JoinType::LeftSemi,
2398 [
2399 FunctionalDependency::with_indices(2, &[0], &[1]),
2401 ]
2402 .into_iter()
2403 .collect::<HashSet<_>>(),
2404 ),
2405 (
2406 JoinType::LeftAnti,
2407 [
2408 FunctionalDependency::with_indices(2, &[0], &[1]),
2410 ]
2411 .into_iter()
2412 .collect::<HashSet<_>>(),
2413 ),
2414 (
2415 JoinType::RightSemi,
2416 [
2417 FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2419 ]
2420 .into_iter()
2421 .collect::<HashSet<_>>(),
2422 ),
2423 (
2424 JoinType::RightAnti,
2425 [
2426 FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2428 ]
2429 .into_iter()
2430 .collect::<HashSet<_>>(),
2431 ),
2432 ];
2433
2434 for (join_type, expected_res) in expected_fd_set {
2435 let join = LogicalJoin::new(
2436 left.clone().into(),
2437 right.clone().into(),
2438 join_type,
2439 Condition::with_expr(on.clone()),
2440 );
2441 let fd_set = join
2442 .functional_dependency()
2443 .as_dependencies()
2444 .iter()
2445 .cloned()
2446 .collect::<HashSet<_>>();
2447 assert_eq!(fd_set, expected_res);
2448 }
2449 }
2450}