1use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31 BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
32 PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch,
33 ToStream, generic,
34};
35use crate::error::{ErrorCode, Result, RwError};
36use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
37use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
38use crate::optimizer::plan_node::generic::DynamicFilter;
39use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
40use crate::optimizer::plan_node::utils::IndicesDisplay;
41use crate::optimizer::plan_node::{
42 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
43 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
44 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
45};
46use crate::optimizer::plan_visitor::LogicalCardinalityExt;
47use crate::optimizer::property::{Distribution, RequiredDist};
48use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
49
/// Logical join of two plans.
///
/// The `core` holds the left and right inputs, the join condition (`on`), the [`JoinType`] and
/// the `output_indices` that select which internal (left ++ right) columns the node outputs.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    pub base: PlanBase<Logical>,
    core: generic::Join<PlanRef>,
}
61
62impl Distill for LogicalJoin {
63 fn distill<'a>(&self) -> XmlNode<'a> {
64 let verbose = self.base.ctx().is_explain_verbose();
65 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
66 vec.push(("type", Pretty::debug(&self.join_type())));
67
68 let concat_schema = self.core.concat_schema();
69 let cond = Pretty::debug(&ConditionDisplay {
70 condition: self.on(),
71 input_schema: &concat_schema,
72 });
73 vec.push(("on", cond));
74
75 if verbose {
76 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
77 vec.push(("output", data));
78 }
79
80 childless_record("LogicalJoin", vec)
81 }
82}
83
84impl LogicalJoin {
85 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
86 let core = generic::Join::with_full_output(left, right, join_type, on);
87 Self::with_core(core)
88 }
89
90 pub(crate) fn with_output_indices(
91 left: PlanRef,
92 right: PlanRef,
93 join_type: JoinType,
94 on: Condition,
95 output_indices: Vec<usize>,
96 ) -> Self {
97 let core = generic::Join::new(left, right, on, join_type, output_indices);
98 Self::with_core(core)
99 }
100
101 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
102 let base = PlanBase::new_logical_with_core(&core);
103 LogicalJoin { base, core }
104 }
105
106 pub fn create(
107 left: PlanRef,
108 right: PlanRef,
109 join_type: JoinType,
110 on_clause: ExprImpl,
111 ) -> PlanRef {
112 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
113 }
114
115 pub fn internal_column_num(&self) -> usize {
116 self.core.internal_column_num()
117 }
118
119 pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
120 self.core.i2l_col_mapping_ignore_join_type()
121 }
122
123 pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
124 self.core.i2r_col_mapping_ignore_join_type()
125 }
126
127 pub fn on(&self) -> &Condition {
129 &self.core.on
130 }
131
132 pub fn core(&self) -> &generic::Join<PlanRef> {
133 &self.core
134 }
135
136 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
138 let input_refs = self
139 .core
140 .on
141 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
142 let index_group = input_refs
143 .ones()
144 .chunk_by(|i| *i < self.core.left.schema().len());
145 let left_index = index_group
146 .into_iter()
147 .next()
148 .map_or(vec![], |group| group.1.collect_vec());
149 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
150 group
151 .1
152 .map(|i| i - self.core.left.schema().len())
153 .collect_vec()
154 });
155 (left_index, right_index)
156 }
157
158 pub fn join_type(&self) -> JoinType {
160 self.core.join_type
161 }
162
163 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
165 self.core.eq_indexes()
166 }
167
168 pub fn output_indices(&self) -> &Vec<usize> {
170 &self.core.output_indices
171 }
172
173 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
175 Self::with_core(generic::Join {
176 output_indices,
177 ..self.core.clone()
178 })
179 }
180
181 pub fn clone_with_cond(&self, on: Condition) -> Self {
183 Self::with_core(generic::Join {
184 on,
185 ..self.core.clone()
186 })
187 }
188
189 pub fn is_left_join(&self) -> bool {
190 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
191 }
192
193 pub fn is_right_join(&self) -> bool {
194 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
195 }
196
197 pub fn is_full_out(&self) -> bool {
198 self.core.is_full_out()
199 }
200
201 pub fn is_asof_join(&self) -> bool {
202 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
203 }
204
205 pub fn output_indices_are_trivial(&self) -> bool {
206 self.output_indices() == &(0..self.internal_column_num()).collect_vec()
207 }
208
209 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
214 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
215 JoinType::LeftOuter => (false, true),
216 JoinType::RightOuter => (true, false),
217 JoinType::FullOuter => (true, true),
218 _ => return join_type,
219 };
220
221 for expr in &predicate.conjunctions {
222 if let ExprImpl::FunctionCall(func) = expr {
223 match func.func_type() {
224 ExprType::Equal
225 | ExprType::NotEqual
226 | ExprType::LessThan
227 | ExprType::LessThanOrEqual
228 | ExprType::GreaterThan
229 | ExprType::GreaterThanOrEqual => {
230 for input in func.inputs() {
231 if let ExprImpl::InputRef(input) = input {
232 let idx = input.index;
233 if idx < left_col_num {
234 gen_null_in_left = false;
235 } else {
236 gen_null_in_right = false;
237 }
238 }
239 }
240 }
241 _ => {}
242 };
243 }
244 }
245
246 match (gen_null_in_left, gen_null_in_right) {
247 (true, true) => JoinType::FullOuter,
248 (true, false) => JoinType::RightOuter,
249 (false, true) => JoinType::LeftOuter,
250 (false, false) => JoinType::Inner,
251 }
252 }
253
254 fn to_batch_lookup_join_with_index_selection(
258 &self,
259 predicate: EqJoinPredicate,
260 batch_join: generic::Join<BatchPlanRef>,
261 ) -> Result<Option<BatchLookupJoin>> {
262 match batch_join.join_type {
263 JoinType::Inner
264 | JoinType::LeftOuter
265 | JoinType::LeftSemi
266 | JoinType::LeftAnti
267 | JoinType::AsofInner
268 | JoinType::AsofLeftOuter => {}
269 _ => return Ok(None),
270 };
271
272 let right = self.right();
274 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
276 logical_scan
277 } else {
278 return Ok(None);
279 };
280
281 let mut result_plan = None;
282 if let Some(lookup_join) =
284 self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
285 {
286 result_plan = Some(lookup_join);
287 }
288
289 if self
290 .core
291 .ctx()
292 .session_ctx()
293 .config()
294 .enable_index_selection()
295 {
296 let indexes = logical_scan.table_indexes();
297 for index in indexes {
298 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
299 let index_scan: PlanRef = index_scan.into();
300 let that = self.clone_with_left_right(self.left(), index_scan.clone());
301 let mut new_batch_join = batch_join.clone();
302 new_batch_join.right =
303 index_scan.to_batch().expect("index scan failed to batch");
304
305 if let Some(lookup_join) =
307 that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
308 {
309 match &result_plan {
310 None => result_plan = Some(lookup_join),
311 Some(prev_lookup_join) => {
312 if prev_lookup_join.lookup_prefix_len()
314 < lookup_join.lookup_prefix_len()
315 {
316 result_plan = Some(lookup_join)
317 }
318 }
319 }
320 }
321 }
322 }
323 }
324
325 Ok(result_plan)
326 }
327
328 fn to_batch_lookup_join(
330 &self,
331 predicate: EqJoinPredicate,
332 logical_join: generic::Join<BatchPlanRef>,
333 ) -> Result<Option<BatchLookupJoin>> {
334 let logical_scan: &LogicalScan =
335 if let Some(logical_scan) = self.core.right.as_logical_scan() {
336 logical_scan
337 } else {
338 return Ok(None);
339 };
340 Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
341 }
342
343 pub fn gen_batch_lookup_join(
344 logical_scan: &LogicalScan,
345 predicate: EqJoinPredicate,
346 logical_join: generic::Join<BatchPlanRef>,
347 is_as_of: bool,
348 ) -> Result<Option<BatchLookupJoin>> {
349 match logical_join.join_type {
350 JoinType::Inner
351 | JoinType::LeftOuter
352 | JoinType::LeftSemi
353 | JoinType::LeftAnti
354 | JoinType::AsofInner
355 | JoinType::AsofLeftOuter => {}
356 _ => return Ok(None),
357 };
358
359 let table = logical_scan.table();
360 let output_column_ids = logical_scan.output_column_ids();
361
362 let order_col_ids = table.order_column_ids();
365 let dist_key = table.distribution_key.clone();
366 let mut dist_key_in_order_key_pos = vec![];
368 for d in dist_key {
369 let pos = table
370 .order_column_indices()
371 .position(|x| x == d)
372 .expect("dist_key must in order_key");
373 dist_key_in_order_key_pos.push(pos);
374 }
375 let shortest_prefix_len = dist_key_in_order_key_pos
377 .iter()
378 .max()
379 .map_or(0, |pos| pos + 1);
380
381 if shortest_prefix_len == 0 {
383 return Ok(None);
384 }
385
386 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
388 for order_col_id in order_col_ids {
389 let mut found = false;
390 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
391 if order_col_id == output_column_ids[eq_idx] {
392 reorder_idx.push(i);
393 found = true;
394 break;
395 }
396 }
397 if !found {
398 break;
399 }
400 }
401 if reorder_idx.len() < shortest_prefix_len {
402 return Ok(None);
403 }
404 let lookup_prefix_len = reorder_idx.len();
405 let predicate = predicate.reorder(&reorder_idx);
406
407 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
409 let o2r = if let Some(project_expr) = project_expr {
411 project_expr
412 .into_iter()
413 .map(|x| x.as_input_ref().unwrap().index)
414 .collect_vec()
415 } else {
416 (0..logical_scan.output_col_idx().len()).collect_vec()
417 };
418 let left_schema_len = logical_join.left.schema().len();
419
420 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
421 offset: left_schema_len,
422 mapping: o2r.clone(),
423 };
424
425 let new_eq_cond = predicate
426 .eq_cond()
427 .rewrite_expr(&mut join_predicate_rewriter);
428
429 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
430 offset: left_schema_len,
431 };
432
433 let new_other_cond = predicate
434 .other_cond()
435 .clone()
436 .rewrite_expr(&mut join_predicate_rewriter)
437 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
438
439 let new_join_on = new_eq_cond.and(new_other_cond);
440 let new_predicate = EqJoinPredicate::create(
441 left_schema_len,
442 new_scan.schema().len(),
443 new_join_on.clone(),
444 );
445
446 if !new_predicate.has_eq() {
449 return Ok(None);
450 }
451
452 let new_join_output_indices = logical_join
455 .output_indices
456 .iter()
457 .map(|&x| {
458 if x < left_schema_len {
459 x
460 } else {
461 o2r[x - left_schema_len] + left_schema_len
462 }
463 })
464 .collect_vec();
465
466 let new_scan_output_column_ids = new_scan.output_column_ids();
467 let as_of = new_scan.as_of.clone();
468 let new_logical_scan: LogicalScan = new_scan.into();
469
470 let new_logical_join = generic::Join::new(
472 logical_join.left,
473 new_logical_scan.to_batch()?,
474 new_join_on,
475 logical_join.join_type,
476 new_join_output_indices,
477 );
478
479 let asof_desc = is_as_of
480 .then(|| {
481 Self::get_inequality_desc_from_predicate(
482 predicate.other_cond().clone(),
483 left_schema_len,
484 )
485 })
486 .transpose()?;
487
488 Ok(Some(BatchLookupJoin::new(
489 new_logical_join,
490 new_predicate,
491 table.clone(),
492 new_scan_output_column_ids,
493 lookup_prefix_len,
494 false,
495 as_of,
496 asof_desc,
497 )))
498 }
499
500 pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
501 self.core.decompose()
502 }
503}
504
505impl PlanTreeNodeBinary<Logical> for LogicalJoin {
506 fn left(&self) -> PlanRef {
507 self.core.left.clone()
508 }
509
510 fn right(&self) -> PlanRef {
511 self.core.right.clone()
512 }
513
514 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
515 Self::with_core(generic::Join {
516 left,
517 right,
518 ..self.core.clone()
519 })
520 }
521
522 fn rewrite_with_left_right(
523 &self,
524 left: PlanRef,
525 left_col_change: ColIndexMapping,
526 right: PlanRef,
527 right_col_change: ColIndexMapping,
528 ) -> (Self, ColIndexMapping) {
529 let (new_on, new_output_indices) = {
530 let (mut map, _) = left_col_change.clone().into_parts();
531 let (mut right_map, _) = right_col_change.clone().into_parts();
532 for i in right_map.iter_mut().flatten() {
533 *i += left.schema().len();
534 }
535 map.append(&mut right_map);
536 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
537
538 let new_output_indices = self
539 .output_indices()
540 .iter()
541 .map(|&i| mapping.map(i))
542 .collect::<Vec<_>>();
543 let new_on = self.on().clone().rewrite_expr(&mut mapping);
544 (new_on, new_output_indices)
545 };
546
547 let join = Self::with_output_indices(
548 left,
549 right,
550 self.join_type(),
551 new_on,
552 new_output_indices.clone(),
553 );
554
555 let new_i2o = ColIndexMapping::with_remaining_columns(
556 &new_output_indices,
557 join.internal_column_num(),
558 );
559
560 let old_o2i = self.core.o2i_col_mapping();
561
562 let old_o2l = old_o2i
563 .composite(&self.core.i2l_col_mapping())
564 .composite(&left_col_change);
565 let old_o2r = old_o2i
566 .composite(&self.core.i2r_col_mapping())
567 .composite(&right_col_change);
568 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
569 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
570
571 let out_col_change = old_o2l
572 .composite(&new_l2o)
573 .union(&old_o2r.composite(&new_r2o));
574 (join, out_col_change)
575 }
576}
577
// Expands the shared plan-tree-node boilerplate for the binary logical node `LogicalJoin`
// (presumably delegating to the `PlanTreeNodeBinary` impl — see the macro definition for
// exactly which items it emits).
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
579
580impl ColPrunable for LogicalJoin {
581 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
582 let required_cols = required_cols
584 .iter()
585 .map(|i| self.output_indices()[*i])
586 .collect_vec();
587 let left_len = self.left().schema().fields.len();
588
589 let total_len = self.left().schema().len() + self.right().schema().len();
590 let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
591
592 required_cols.iter().for_each(|&i| {
593 if self.is_right_join() {
594 resized_required_cols.insert(left_len + i);
595 } else {
596 resized_required_cols.insert(i);
597 }
598 });
599
600 let mut visitor = CollectInputRef::new(resized_required_cols);
603 self.on().visit_expr(&mut visitor);
604 let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
605
606 let mut left_required_cols = Vec::new();
607 let mut right_required_cols = Vec::new();
608 left_right_required_cols.iter().for_each(|&i| {
609 if i < left_len {
610 left_required_cols.push(i);
611 } else {
612 right_required_cols.push(i - left_len);
613 }
614 });
615
616 let mut on = self.on().clone();
617 let mut mapping =
618 ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
619 on = on.rewrite_expr(&mut mapping);
620
621 let new_output_indices = {
622 let required_inputs_in_output = if self.is_left_join() {
623 &left_required_cols
624 } else if self.is_right_join() {
625 &right_required_cols
626 } else {
627 &left_right_required_cols
628 };
629
630 let mapping =
631 ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
632 required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
633 };
634
635 LogicalJoin::with_output_indices(
636 self.left().prune_col(&left_required_cols, ctx),
637 self.right().prune_col(&right_required_cols, ctx),
638 self.join_type(),
639 on,
640 new_output_indices,
641 )
642 .into()
643 }
644}
645
646impl ExprRewritable<Logical> for LogicalJoin {
647 fn has_rewritable_expr(&self) -> bool {
648 true
649 }
650
651 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
652 let mut core = self.core.clone();
653 core.rewrite_exprs(r);
654 Self {
655 base: self.base.clone_with_new_plan_id(),
656 core,
657 }
658 .into()
659 }
660}
661
662impl ExprVisitable for LogicalJoin {
663 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
664 self.core.visit_exprs(v);
665 }
666}
667
668fn derive_predicate_from_eq_condition(
686 expr: &ExprImpl,
687 eq_condition: &EqJoinPredicate,
688 col_num: usize,
689 expr_is_left: bool,
690) -> Option<ExprImpl> {
691 if expr.is_impure() {
692 return None;
693 }
694 let eq_indices = eq_condition
695 .eq_indexes_typed()
696 .iter()
697 .filter_map(|(l, r)| {
698 if l.return_type() != r.return_type() {
699 None
700 } else if expr_is_left {
701 Some(l.index())
702 } else {
703 Some(r.index())
704 }
705 })
706 .collect_vec();
707 if expr
708 .collect_input_refs(col_num)
709 .ones()
710 .any(|index| !eq_indices.contains(&index))
711 {
712 return None;
714 }
715 let other_side_mapping = if expr_is_left {
718 eq_condition.eq_indexes_typed().into_iter().collect()
719 } else {
720 eq_condition
721 .eq_indexes_typed()
722 .into_iter()
723 .map(|(x, y)| (y, x))
724 .collect()
725 };
726 struct InputRefsRewriter {
727 mapping: HashMap<InputRef, InputRef>,
728 }
729 impl ExprRewriter for InputRefsRewriter {
730 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
731 self.mapping[&input_ref].clone().into()
732 }
733 }
734 Some(
735 InputRefsRewriter {
736 mapping: other_side_mapping,
737 }
738 .rewrite_expr(expr.clone()),
739 )
740}
741
742struct LookupJoinPredicateRewriter {
744 offset: usize,
745 mapping: Vec<usize>,
746}
747impl ExprRewriter for LookupJoinPredicateRewriter {
748 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
749 if input_ref.index() < self.offset {
750 input_ref.into()
751 } else {
752 InputRef::new(
753 self.mapping[input_ref.index() - self.offset] + self.offset,
754 input_ref.return_type(),
755 )
756 .into()
757 }
758 }
759}
760
761struct LookupJoinScanPredicateRewriter {
763 offset: usize,
764}
765impl ExprRewriter for LookupJoinScanPredicateRewriter {
766 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
767 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
768 }
769}
770
771impl PredicatePushdown for LogicalJoin {
772 fn predicate_pushdown(
796 &self,
797 predicate: Condition,
798 ctx: &mut PredicatePushdownContext,
799 ) -> PlanRef {
800 let mut predicate = {
802 let mut mapping = self.core.o2i_col_mapping();
803 predicate.rewrite_expr(&mut mapping)
804 };
805
806 let left_col_num = self.left().schema().len();
807 let right_col_num = self.right().schema().len();
808 let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
809
810 let push_down_temporal_predicate = !self.should_be_temporal_join();
811
812 let (left_from_filter, right_from_filter, on) = push_down_into_join(
813 &mut predicate,
814 left_col_num,
815 right_col_num,
816 join_type,
817 push_down_temporal_predicate,
818 );
819
820 let mut new_on = self.on().clone().and(on);
821 let (left_from_on, right_from_on) = push_down_join_condition(
822 &mut new_on,
823 left_col_num,
824 right_col_num,
825 join_type,
826 push_down_temporal_predicate,
827 );
828
829 let left_predicate = left_from_filter.and(left_from_on);
830 let right_predicate = right_from_filter.and(right_from_on);
831
832 let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
834
835 let right_from_left = if matches!(
837 join_type,
838 JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
839 ) {
840 Condition {
841 conjunctions: left_predicate
842 .conjunctions
843 .iter()
844 .filter_map(|expr| {
845 derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
846 })
847 .collect(),
848 }
849 } else {
850 Condition::true_cond()
851 };
852
853 let left_from_right = if matches!(
855 join_type,
856 JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
857 ) {
858 Condition {
859 conjunctions: right_predicate
860 .conjunctions
861 .iter()
862 .filter_map(|expr| {
863 derive_predicate_from_eq_condition(
864 expr,
865 &eq_condition,
866 right_col_num,
867 false,
868 )
869 })
870 .collect(),
871 }
872 } else {
873 Condition::true_cond()
874 };
875
876 let left_predicate = left_predicate.and(left_from_right);
877 let right_predicate = right_predicate.and(right_from_left);
878
879 let new_left = self.left().predicate_pushdown(left_predicate, ctx);
880 let new_right = self.right().predicate_pushdown(right_predicate, ctx);
881 let new_join = LogicalJoin::with_output_indices(
882 new_left,
883 new_right,
884 join_type,
885 new_on,
886 self.output_indices().clone(),
887 );
888
889 let mut mapping = self.core.i2o_col_mapping();
890 predicate = predicate.rewrite_expr(&mut mapping);
891 LogicalFilter::create(new_join.into(), predicate)
892 }
893}
894
895impl LogicalJoin {
    /// Converts both inputs to streaming plans distributed on their equality-key columns,
    /// for use by a streaming hash join.
    ///
    /// Each input is first offered a better-locality variant keyed on its join-key columns
    /// (`try_better_locality`). If none is offered, the input itself is used. The right side is
    /// requested as `shard_by_key` on its eq-keys, and then:
    /// * if it ends up `HashShard`, the left side must match that distribution, translated
    ///   through the right-to-left eq-key mapping;
    /// * if it ends up `UpstreamHashShard`, the left side is requested as `shard_by_key` on its
    ///   own eq-keys. If the left then is `HashShard`, the right side is enforced to match it
    ///   via the left-to-right mapping. Otherwise both sides are explicitly hash-sharded on
    ///   their eq-keys.
    ///
    /// # Errors
    /// Propagates errors from converting or re-distributing either input.
    ///
    /// # Panics
    /// Hits `unreachable!()` if a side's distribution is neither `HashShard` nor
    /// `UpstreamHashShard` after the request above.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
        let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();

        let logical_right = self
            .right()
            .try_better_locality(&rhs_join_key_idx)
            .unwrap_or_else(|| self.right());
        let mut right = logical_right.to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        let logical_left = self
            .left()
            .try_better_locality(&lhs_join_key_idx)
            .unwrap_or_else(|| self.left());

        // Eq-key column mappings between the two sides, used to mirror one side's
        // distribution onto the other.
        let r2l =
            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // Make the left side follow the right side's concrete hash distribution.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                left = logical_left.to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Make the right side follow the left side's hash distribution.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Neither side's distribution can be mirrored onto the other, so
                        // hash-shard both on their eq-keys explicitly.
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
960
961 fn to_stream_hash_join(
962 &self,
963 predicate: EqJoinPredicate,
964 ctx: &mut ToStreamContext,
965 ) -> Result<StreamPlanRef> {
966 use super::stream::prelude::*;
967
968 assert!(predicate.has_eq());
969 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
970
971 let core = self.core.clone_with_inputs(left, right);
972
973 let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
982
983 let force_filter_inside_join = self
984 .base
985 .ctx()
986 .session_ctx()
987 .config()
988 .streaming_force_filter_inside_join();
989
990 let pull_filter = self.join_type() == JoinType::Inner
991 && stream_hash_join.eq_join_predicate().has_non_eq()
992 && stream_hash_join.inequality_pairs().is_empty()
993 && (!force_filter_inside_join);
994 if pull_filter {
995 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
996
997 let mut core = core.clone();
998 core.output_indices = default_indices.clone();
999 let eq_cond = EqJoinPredicate::new(
1001 Condition::true_cond(),
1002 predicate.eq_keys().to_vec(),
1003 self.left().schema().len(),
1004 self.right().schema().len(),
1005 );
1006 core.on = eq_cond.eq_cond();
1007 let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
1008 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1009 let plan = StreamFilter::new(logical_filter).into();
1010 if self.output_indices() != &default_indices {
1011 let logical_project = generic::Project::with_mapping(
1012 plan,
1013 ColIndexMapping::with_remaining_columns(
1014 self.output_indices(),
1015 self.internal_column_num(),
1016 ),
1017 );
1018 Ok(StreamProject::new(logical_project).into())
1019 } else {
1020 Ok(plan)
1021 }
1022 } else {
1023 Ok(stream_hash_join.into())
1024 }
1025 }
1026
1027 fn should_be_temporal_join(&self) -> bool {
1028 let right = self.right();
1029 if let Some(logical_scan) = right.as_logical_scan() {
1030 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1031 } else {
1032 false
1033 }
1034 }
1035
1036 fn to_stream_temporal_join_with_index_selection(
1037 &self,
1038 predicate: EqJoinPredicate,
1039 ctx: &mut ToStreamContext,
1040 ) -> Result<StreamPlanRef> {
1041 let right = self.right();
1043 let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1045
1046 let mut result_plan: Result<StreamTemporalJoin> =
1048 self.to_stream_temporal_join(predicate.clone(), ctx);
1049 if let Ok(temporal_join) = &result_plan
1051 && temporal_join.eq_join_predicate().eq_indexes().len()
1052 == logical_scan.primary_key().len()
1053 {
1054 return result_plan.map(|x| x.into());
1055 }
1056 if self
1057 .core
1058 .ctx()
1059 .session_ctx()
1060 .config()
1061 .enable_index_selection()
1062 {
1063 let indexes = logical_scan.table_indexes();
1064 for index in indexes {
1065 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1067 let index_scan: PlanRef = index_scan.into();
1068 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1069 if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx)
1070 {
1071 match &result_plan {
1072 Err(_) => result_plan = Ok(temporal_join),
1073 Ok(prev_temporal_join) => {
1074 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1076 < temporal_join.eq_join_predicate().eq_indexes().len()
1077 {
1078 result_plan = Ok(temporal_join)
1079 }
1080 }
1081 }
1082 }
1083 }
1084 }
1085 }
1086
1087 result_plan.map(|x| x.into())
1088 }
1089
1090 fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1091 let Some(logical_scan) = right.as_logical_scan() else {
1092 return Err(RwError::from(ErrorCode::NotSupported(
1093 "Temporal join requires a table scan as its lookup table".into(),
1094 "Please provide a table scan".into(),
1095 )));
1096 };
1097
1098 if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1099 return Err(RwError::from(ErrorCode::NotSupported(
1100 "Temporal join requires a table defined as temporal table".into(),
1101 "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1102 )));
1103 }
1104 Ok(logical_scan)
1105 }
1106
1107 fn temporal_join_scan_predicate_pull_up(
1108 logical_scan: &LogicalScan,
1109 predicate: EqJoinPredicate,
1110 output_indices: &[usize],
1111 left_schema_len: usize,
1112 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1113 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1115 let o2r = if let Some(project_expr) = project_expr {
1117 project_expr
1118 .into_iter()
1119 .map(|x| x.as_input_ref().unwrap().index)
1120 .collect_vec()
1121 } else {
1122 (0..logical_scan.output_col_idx().len()).collect_vec()
1123 };
1124 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1125 offset: left_schema_len,
1126 mapping: o2r.clone(),
1127 };
1128
1129 let new_eq_cond = predicate
1130 .eq_cond()
1131 .rewrite_expr(&mut join_predicate_rewriter);
1132
1133 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1134 offset: left_schema_len,
1135 };
1136
1137 let new_other_cond = predicate
1138 .other_cond()
1139 .clone()
1140 .rewrite_expr(&mut join_predicate_rewriter)
1141 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1142
1143 let new_join_on = new_eq_cond.and(new_other_cond);
1144
1145 let new_predicate = EqJoinPredicate::create(
1146 left_schema_len,
1147 new_scan.schema().len(),
1148 new_join_on.clone(),
1149 );
1150
1151 let new_join_output_indices = output_indices
1154 .iter()
1155 .map(|&x| {
1156 if x < left_schema_len {
1157 x
1158 } else {
1159 o2r[x - left_schema_len] + left_schema_len
1160 }
1161 })
1162 .collect_vec();
1163
1164 if new_scan.cross_database() {
1166 return Err(RwError::from(ErrorCode::NotSupported(
1167 "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1168 "Please ensure both tables are in the same database".into(),
1169 )));
1170 }
1171 let new_stream_table_scan =
1172 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1173 Ok((
1174 new_stream_table_scan,
1175 new_predicate,
1176 new_join_on,
1177 new_join_output_indices,
1178 ))
1179 }
1180
1181 fn to_stream_temporal_join(
1182 &self,
1183 predicate: EqJoinPredicate,
1184 ctx: &mut ToStreamContext,
1185 ) -> Result<StreamTemporalJoin> {
1186 use super::stream::prelude::*;
1187
1188 assert!(predicate.has_eq());
1189
1190 let right = self.right();
1191
1192 let logical_scan = Self::check_temporal_rhs(&right)?;
1193
1194 let table = logical_scan.table();
1195 let output_column_ids = logical_scan.output_column_ids();
1196
1197 let order_col_ids = table.order_column_ids();
1200 let dist_key = table.distribution_key.clone();
1201
1202 let mut dist_key_in_order_key_pos = vec![];
1203 for d in dist_key {
1204 let pos = table
1205 .order_column_indices()
1206 .position(|x| x == d)
1207 .expect("dist_key must in order_key");
1208 dist_key_in_order_key_pos.push(pos);
1209 }
1210 let shortest_prefix_len = dist_key_in_order_key_pos
1212 .iter()
1213 .max()
1214 .map_or(0, |pos| pos + 1);
1215
1216 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1218 for order_col_id in order_col_ids {
1219 let mut found = false;
1220 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1221 if order_col_id == output_column_ids[eq_idx] {
1222 reorder_idx.push(i);
1223 found = true;
1224 break;
1225 }
1226 }
1227 if !found {
1228 break;
1229 }
1230 }
1231 if reorder_idx.len() < shortest_prefix_len {
1232 return Err(RwError::from(ErrorCode::NotSupported(
1234 "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1235 "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1236 )));
1237 }
1238 let lookup_prefix_len = reorder_idx.len();
1239 let predicate = predicate.reorder(&reorder_idx);
1240
1241 let required_dist = if dist_key_in_order_key_pos.is_empty() {
1242 RequiredDist::single()
1243 } else {
1244 let left_eq_indexes = predicate.left_eq_indexes();
1245 let left_dist_key = dist_key_in_order_key_pos
1246 .iter()
1247 .map(|pos| left_eq_indexes[*pos])
1248 .collect_vec();
1249
1250 RequiredDist::hash_shard(&left_dist_key)
1251 };
1252
1253 let lhs_join_key_idx = predicate
1254 .eq_indexes()
1255 .into_iter()
1256 .map(|(l, _)| l)
1257 .collect_vec();
1258 let logical_left = self
1259 .left()
1260 .try_better_locality(&lhs_join_key_idx)
1261 .unwrap_or_else(|| self.left());
1262 let left = logical_left.to_stream(ctx)?;
1263 let left = required_dist.stream_enforce(left);
1265
1266 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1267 Self::temporal_join_scan_predicate_pull_up(
1268 logical_scan,
1269 predicate,
1270 self.output_indices(),
1271 self.left().schema().len(),
1272 )?;
1273
1274 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1275 if !new_predicate.has_eq() {
1276 return Err(RwError::from(ErrorCode::NotSupported(
1277 "Temporal join requires a non trivial join condition".into(),
1278 "Please remove the false condition of the join".into(),
1279 )));
1280 }
1281
1282 let new_logical_join = generic::Join::new(
1284 left,
1285 right,
1286 new_join_on,
1287 self.join_type(),
1288 new_join_output_indices,
1289 );
1290
1291 let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1292
1293 StreamTemporalJoin::new(new_logical_join, new_predicate, false)
1294 }
1295
1296 fn to_stream_nested_loop_temporal_join(
1297 &self,
1298 predicate: EqJoinPredicate,
1299 ctx: &mut ToStreamContext,
1300 ) -> Result<StreamPlanRef> {
1301 use super::stream::prelude::*;
1302 assert!(!predicate.has_eq());
1303
1304 let left = self.left().to_stream_with_dist_required(
1305 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1306 ctx,
1307 )?;
1308 assert!(left.as_stream_exchange().is_some());
1309
1310 if self.join_type() != JoinType::Inner {
1311 return Err(RwError::from(ErrorCode::NotSupported(
1312 "Temporal join requires an inner join".into(),
1313 "Please use an inner join".into(),
1314 )));
1315 }
1316
1317 if !left.append_only() {
1318 return Err(RwError::from(ErrorCode::NotSupported(
1319 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1320 "Please ensure the left hash side is append only".into(),
1321 )));
1322 }
1323
1324 let right = self.right();
1325 let logical_scan = Self::check_temporal_rhs(&right)?;
1326
1327 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1328 Self::temporal_join_scan_predicate_pull_up(
1329 logical_scan,
1330 predicate,
1331 self.output_indices(),
1332 self.left().schema().len(),
1333 )?;
1334
1335 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1336
1337 let new_logical_join = generic::Join::new(
1339 left,
1340 right,
1341 new_join_on,
1342 self.join_type(),
1343 new_join_output_indices,
1344 );
1345
1346 Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1347 }
1348
1349 fn to_stream_dynamic_filter(
1350 &self,
1351 predicate: Condition,
1352 ctx: &mut ToStreamContext,
1353 ) -> Result<Option<StreamPlanRef>> {
1354 use super::stream::prelude::*;
1355
1356 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1362 return Ok(None);
1363 }
1364
1365 if !self.right().max_one_row() {
1367 return Ok(None);
1368 }
1369 if self.right().schema().len() != 1 {
1370 return Ok(None);
1371 }
1372
1373 if predicate.conjunctions.len() > 1 {
1375 return Ok(None);
1376 }
1377 let expr: ExprImpl = predicate.into();
1378 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1379 Some(v) => v,
1380 None => return Ok(None),
1381 };
1382
1383 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1384 && right_ref.index == self.left().schema().len() ;
1385 if !condition_cross_inputs {
1386 return Ok(None);
1388 }
1389
1390 if self.left().schema().fields()[left_ref.index].data_type
1392 != self.right().schema().fields()[0].data_type
1393 {
1394 return Ok(None);
1395 }
1396
1397 let all_output_from_left = self
1399 .output_indices()
1400 .iter()
1401 .all(|i| *i < self.left().schema().len());
1402 if !all_output_from_left {
1403 return Ok(None);
1404 }
1405
1406 let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1407 let right = self.right().to_stream_with_dist_required(
1408 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1409 ctx,
1410 )?;
1411
1412 assert!(right.as_stream_exchange().is_some());
1413 assert_eq!(
1414 *right.inputs().iter().exactly_one().unwrap().distribution(),
1415 Distribution::Single
1416 );
1417
1418 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1419 let plan = StreamDynamicFilter::new(core)?.into();
1420 if self
1422 .output_indices()
1423 .iter()
1424 .copied()
1425 .ne(0..self.left().schema().len())
1426 {
1427 let logical_project = generic::Project::with_mapping(
1430 plan,
1431 ColIndexMapping::with_remaining_columns(
1432 self.output_indices(),
1433 self.left().schema().len(),
1434 ),
1435 );
1436 Ok(Some(StreamProject::new(logical_project).into()))
1437 } else {
1438 Ok(Some(plan))
1439 }
1440 }
1441
1442 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1443 let predicate = EqJoinPredicate::create(
1444 self.left().schema().len(),
1445 self.right().schema().len(),
1446 self.on().clone(),
1447 );
1448 assert!(predicate.has_eq());
1449
1450 let join = self
1451 .core
1452 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1453
1454 Ok(self
1455 .to_batch_lookup_join(predicate, join)?
1456 .expect("Fail to convert to lookup join")
1457 .into())
1458 }
1459
1460 fn to_stream_asof_join(
1461 &self,
1462 predicate: EqJoinPredicate,
1463 ctx: &mut ToStreamContext,
1464 ) -> Result<StreamPlanRef> {
1465 use super::stream::prelude::*;
1466
1467 if predicate.eq_keys().is_empty() {
1468 return Err(ErrorCode::InvalidInputSyntax(
1469 "AsOf join requires at least 1 equal condition".to_owned(),
1470 )
1471 .into());
1472 }
1473
1474 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1475 let left_len = left.schema().len();
1476 let core = self.core.clone_with_inputs(left, right);
1477
1478 let inequality_desc =
1479 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1480
1481 Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1482 }
1483
1484 fn to_batch_hash_join(
1486 &self,
1487 logical_join: generic::Join<BatchPlanRef>,
1488 predicate: EqJoinPredicate,
1489 ) -> Result<BatchPlanRef> {
1490 use super::batch::prelude::*;
1491
1492 let left_schema_len = logical_join.left.schema().len();
1493 let asof_desc = self
1494 .is_asof_join()
1495 .then(|| {
1496 Self::get_inequality_desc_from_predicate(
1497 predicate.other_cond().clone(),
1498 left_schema_len,
1499 )
1500 })
1501 .transpose()?;
1502
1503 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1504 Ok(batch_join.into())
1505 }
1506
1507 pub fn get_inequality_desc_from_predicate(
1508 predicate: Condition,
1509 left_input_len: usize,
1510 ) -> Result<AsOfJoinDesc> {
1511 let expr: ExprImpl = predicate.into();
1512 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1513 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1514 {
1515 Ok(AsOfJoinDesc {
1516 left_idx: left_input_ref.index() as u32,
1517 right_idx: (right_input_ref.index() - left_input_len) as u32,
1518 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1519 })
1520 } else {
1521 bail!("inequal condition from the same side should be push down in optimizer");
1522 }
1523 } else {
1524 Err(ErrorCode::InvalidInputSyntax(
1525 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1526 )
1527 .into())
1528 }
1529 }
1530
1531 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1532 match expr_type {
1533 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1534 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1535 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1536 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1537 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1538 "Invalid comparison type: {}",
1539 expr_type.as_str_name()
1540 ))
1541 .into()),
1542 }
1543 }
1544}
1545
1546impl ToBatch for LogicalJoin {
1547 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1548 let predicate = EqJoinPredicate::create(
1549 self.left().schema().len(),
1550 self.right().schema().len(),
1551 self.on().clone(),
1552 );
1553
1554 let batch_join = self
1555 .core
1556 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1557
1558 let ctx = self.base.ctx();
1559 let config = ctx.session_ctx().config();
1560
1561 if predicate.has_eq() {
1562 if !predicate.eq_keys_are_type_aligned() {
1563 return Err(ErrorCode::InternalError(format!(
1564 "Join eq keys are not aligned for predicate: {predicate:?}"
1565 ))
1566 .into());
1567 }
1568 if config.batch_enable_lookup_join()
1569 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1570 predicate.clone(),
1571 batch_join.clone(),
1572 )?
1573 {
1574 return Ok(lookup_join.into());
1575 }
1576 self.to_batch_hash_join(batch_join, predicate)
1577 } else if self.is_asof_join() {
1578 Err(ErrorCode::InvalidInputSyntax(
1579 "AsOf join requires at least 1 equal condition".to_owned(),
1580 )
1581 .into())
1582 } else {
1583 Ok(BatchNestedLoopJoin::new(batch_join).into())
1585 }
1586 }
1587}
1588
impl ToStream for LogicalJoin {
    /// Converts this logical join into a streaming plan.
    ///
    /// Dispatch order:
    /// 1. Reject any `now()` expression still present in the ON condition.
    /// 2. `AsofInner` / `AsofLeftOuter` joins go to `to_stream_asof_join`.
    /// 3. With at least one equality condition: require type-aligned eq keys, then plan a
    ///    temporal join (with index selection) if `should_be_temporal_join()`, or a hash
    ///    join otherwise.
    /// 4. Without equality conditions: plan a nested-loop temporal join if
    ///    `should_be_temporal_join()`. Otherwise use a dynamic filter when one applies, and
    ///    fail with a `NotSupported` error (streaming nested-loop join) if not.
    fn to_stream(
        &self,
        ctx: &mut ToStreamContext,
    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
        // Temporal predicates (`now()`) are expected to have been moved out of the ON
        // condition by earlier optimization; any that remain cannot be planned here.
        if self
            .on()
            .conjunctions
            .iter()
            .any(|cond| cond.count_nows() > 0)
        {
            return Err(ErrorCode::NotSupported(
                "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
                "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
        }

        let predicate = EqJoinPredicate::create(
            self.left().schema().len(),
            self.right().schema().len(),
            self.on().clone(),
        );

        if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
            self.to_stream_asof_join(predicate, ctx)
        } else if predicate.has_eq() {
            if !predicate.eq_keys_are_type_aligned() {
                return Err(ErrorCode::InternalError(format!(
                    "Join eq keys are not aligned for predicate: {predicate:?}"
                ))
                .into());
            }

            if self.should_be_temporal_join() {
                self.to_stream_temporal_join_with_index_selection(predicate, ctx)
            } else {
                self.to_stream_hash_join(predicate, ctx)
            }
        } else if self.should_be_temporal_join() {
            self.to_stream_nested_loop_temporal_join(predicate, ctx)
        } else if let Some(dynamic_filter) =
            // The dynamic-filter check receives the full ON condition, not just the
            // predicate's non-equality part.
            self.to_stream_dynamic_filter(self.on().clone(), ctx)?
        {
            Ok(dynamic_filter)
        } else {
            Err(RwError::from(ErrorCode::NotSupported(
                "streaming nested-loop join".to_owned(),
                "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
                 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
                 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
            )))
        }
    }

    /// Rewrites both inputs for streaming and makes sure the join's output carries the
    /// columns needed to derive a stream key.
    ///
    /// Missing columns are appended to the join's output indices:
    /// - columns of each input's stream key that the join does not already output;
    /// - depending on `add_which_join_key_to_pk`, the equi-join key columns of the left
    ///   side, the right side, or both.
    ///
    /// New columns are only appended at the end, so the returned `out_col_change`, computed
    /// before the extension, stays valid for the pre-existing outputs. For full outer joins
    /// the result is wrapped with `LogicalFilter::filter_out_all_null_keys` over the left and
    /// right stream key columns.
    fn logical_rewrite_for_stream(
        &self,
        ctx: &mut RewriteStreamContext,
    ) -> Result<(PlanRef, ColIndexMapping)> {
        let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
        let left_len = left.schema().len();
        let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
        let (join, out_col_change) = self.rewrite_with_left_right(
            left.clone(),
            left_col_change,
            right.clone(),
            right_col_change,
        );

        // Maps the join's internal columns (left ++ right) to their current output positions.
        let mapping = ColIndexMapping::with_remaining_columns(
            join.output_indices(),
            join.internal_column_num(),
        );

        let l2o = join.core.l2i_col_mapping().composite(&mapping);
        let r2o = join.core.r2i_col_mapping().composite(&mapping);

        // Stream key columns of each input that are not yet output. Right-side indices are
        // shifted by `left_len` into the join's internal column space.
        let mut left_to_add = left
            .expect_stream_key()
            .iter()
            .cloned()
            .filter(|i| l2o.try_map(*i).is_none())
            .collect_vec();

        let mut right_to_add = right
            .expect_stream_key()
            .iter()
            .filter(|&&i| r2o.try_map(i).is_none())
            .map(|&i| i + left_len)
            .collect_vec();

        // Also add missing equi-join key columns, on the side(s) selected by
        // `add_which_join_key_to_pk`.
        let right_len = right.schema().len();
        let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());

        let either_or_both = self.core.add_which_join_key_to_pk();

        for (lk, rk) in eq_predicate.eq_indexes() {
            match either_or_both {
                EitherOrBoth::Left(_) => {
                    if l2o.try_map(lk).is_none() {
                        left_to_add.push(lk);
                    }
                }
                EitherOrBoth::Right(_) => {
                    if r2o.try_map(rk).is_none() {
                        right_to_add.push(rk + left_len)
                    }
                }
                EitherOrBoth::Both(_, _) => {
                    if l2o.try_map(lk).is_none() {
                        left_to_add.push(lk);
                    }
                    if r2o.try_map(rk).is_none() {
                        right_to_add.push(rk + left_len)
                    }
                }
            };
        }
        // Drop duplicates (a column may be both a stream key and a join key), keeping
        // first-occurrence order, which determines the appended column order.
        let left_to_add = left_to_add.into_iter().unique();
        let right_to_add = right_to_add.into_iter().unique();
        let mut new_output_indices = join.output_indices().clone();
        // Left-side columns are skipped when `is_right_join()` holds, and right-side columns
        // when `is_left_join()` holds. NOTE(review): presumably these are the semi/anti joins
        // whose output only contains the other side; confirm against `generic::Join`.
        if !join.is_right_join() {
            new_output_indices.extend(left_to_add);
        }
        if !join.is_left_join() {
            new_output_indices.extend(right_to_add);
        }

        let join_with_pk = join.clone_with_output_indices(new_output_indices);

        let plan = if join_with_pk.join_type() == JoinType::FullOuter {
            // Locate both inputs' stream key columns in the extended output, then filter with
            // `filter_out_all_null_keys` on them.
            let l2o = join_with_pk
                .core
                .l2i_col_mapping()
                .composite(&join_with_pk.core.i2o_col_mapping());
            let r2o = join_with_pk
                .core
                .r2i_col_mapping()
                .composite(&join_with_pk.core.i2o_col_mapping());
            let left_right_stream_keys = join_with_pk
                .left()
                .expect_stream_key()
                .iter()
                .map(|i| l2o.map(*i))
                .chain(
                    join_with_pk
                        .right()
                        .expect_stream_key()
                        .iter()
                        .map(|i| r2o.map(*i)),
                )
                .collect_vec();
            let plan: PlanRef = join_with_pk.into();
            LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
        } else {
            join_with_pk.into()
        };

        Ok((plan, out_col_change))
    }

    /// Forwards a request for better locality on output `columns` to the left input.
    ///
    /// This is attempted only if the join can be planned as a dynamic filter. That is checked
    /// by actually running `to_stream_dynamic_filter` with a fresh `ToStreamContext`, and a
    /// successful dynamic filter only outputs left columns. `columns` are then mapped to
    /// input columns through the output-to-input mapping. Returns `None` when:
    /// - the join is not a dynamic filter;
    /// - a column cannot be mapped;
    /// - the left input offers no better plan.
    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
        let mut ctx = ToStreamContext::new(false);
        if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
            let o2i_mapping = self.core.o2i_col_mapping();
            let left_input_columns = columns
                .iter()
                .map(|&col| o2i_mapping.try_map(col))
                .collect::<Option<Vec<usize>>>()?;
            if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
                return Some(
                    self.clone_with_left_right(better_left_plan, self.right())
                        .into(),
                );
            }
        }
        None
    }
}
1776
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Column pruning through an inner join.
    ///
    /// ```text
    /// Join(on: $1 = $3)        -- prune to [2, 3], i.e. (v3, v4)
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// is expected to become
    /// ```text
    /// Join(on: $0 = $2)        -- outputs (v3, v4)
    ///   Values(v2, v3)
    ///   Values(v4)
    /// ```
    #[tokio::test]
    async fn test_prune_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        // Require v3 (from the left) and v4 (from the right).
        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        // The top node is still the join, outputting exactly the required columns.
        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        // The condition is rewritten against the pruned inputs: v2 -> $0, v4 -> $2.
        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        // The left input keeps v2 (join key) and v3 (required); the right keeps only v4.
        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Column pruning through semi/anti joins.
    ///
    /// For each of `LeftSemi`, `RightSemi`, `LeftAnti` and `RightAnti` over
    /// `Values(v1, v2, v3)` and `Values(v4, v5, v6)` with `on: $1 = $4`, pruning to `[0]` and
    /// to `[0, 1, 2]` must keep a join on top. Its output fields must start at the left input
    /// (`v1`), or at the right input (`v4`) when `is_right_join()` holds.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
                ],
            )
            .unwrap(),
        ));
        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Index into `fields` of the first output column: v1 for left joins, v4 for
            // right joins.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();
            let required_cols = vec![0];
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            let required_cols = vec![0, 1, 2];
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 3);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
        }
    }

    /// Pruning to exactly the join key columns.
    ///
    /// ```text
    /// Join(on: $1 = $3)        -- prune to [1, 3], i.e. (v2, v4)
    ///   Values(v1, v2, v3)
    ///   Values(v4, v5, v6)
    /// ```
    /// is expected to yield a join directly at the top:
    /// ```text
    /// Join(on: $0 = $1)
    ///   Values(v2)
    ///   Values(v4)
    /// ```
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        // Both required columns are the join keys themselves.
        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Batch conversion of a join with both equality and non-equality conditions.
    ///
    /// `Join(on: ($1 = $3) AND ($2 = 42))` over `Values(v1, v2, v3)` and `Values(v4, v5, v6)`
    /// is expected to become a `BatchHashJoin`. Its equality condition must be `$1 = $3` and
    /// its first non-equality conjunct `$2 = 42`.
    #[tokio::test]
    async fn test_join_to_batch() {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );

        fn input_ref(i: usize) -> ExprImpl {
            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
        }
        let eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
        ));
        // Compares a left column with a literal, so it is not an equi-join key.
        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    input_ref(2),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Datum::Some(42_i32.into()),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let join_type = JoinType::Inner;
        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on_cond),
        );

        let result = logical_join.to_batch().unwrap();

        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Placeholder for a streaming-conversion test.
    ///
    /// NOTE(review): this test is `#[ignore]`d and has an empty body, so it checks nothing.
    /// Consider restoring its body or removing it.
    #[tokio::test]
    #[ignore] async fn test_join_to_stream() {
    }
    /// Pruning must respect the order in which columns are required.
    ///
    /// The plan is the same as in `test_prune_join`, but the required columns are `[3, 2]`.
    /// The output must be (v4, v3), while the inputs are still pruned to `Values(v2, v3)` and
    /// `Values(v4)` with `on: $0 = $2`.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        // Required in reverse order relative to the join's natural output order.
        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Functional-dependency derivation for every join type.
    ///
    /// The left input `Values(l0, l1)` has FD `{l0} -> {l1}`. The right input
    /// `Values(r0, r1, r2)` has FD `{r0} -> {r1, r2}`. The join condition is
    /// `l0 = 0 AND l1 = r0`, i.e. `$0 = 0 AND $1 = $3`. For each join type, the derived FD
    /// set must equal the expected set exactly.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock().await;
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            // {l0} -> {l1}
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            // {r0} -> {r1, r2}
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };
        // l0 = 0 AND l1 = r0
        let on: ExprImpl = FunctionCall::new(
            Type::And,
            vec![
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(0, DataType::Int32).into(),
                        ExprImpl::literal_int(0),
                    ],
                )
                .unwrap()
                .into(),
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(1, DataType::Int32).into(),
                        InputRef::new(3, DataType::Int32).into(),
                    ],
                )
                .unwrap()
                .into(),
            ],
        )
        .unwrap()
        .into();
        let expected_fd_set = [
            (
                JoinType::Inner,
                [
                    // inherited from the left input
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    // inherited from the right input (shifted by the 2 left columns)
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    // l0 is fixed by `l0 = 0`
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    // l1 and r0 determine each other through `l1 = r0`
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                [
                    // only the right input's FD is expected to survive
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftOuter,
                [
                    // only the left input's FD is expected to survive
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftSemi,
                [
                    // output is the 2 left columns
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftAnti,
                [
                    // output is the 2 left columns
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightSemi,
                [
                    // output is the 3 right columns
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightAnti,
                [
                    // output is the 3 right columns
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}