1use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31 BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
32 PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch,
33 ToStream, generic, try_enforce_locality_requirement,
34};
35use crate::error::{ErrorCode, Result, RwError};
36use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
37use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
38use crate::optimizer::plan_node::generic::DynamicFilter;
39use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
40use crate::optimizer::plan_node::utils::IndicesDisplay;
41use crate::optimizer::plan_node::{
42 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
43 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
44 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
45};
46use crate::optimizer::plan_visitor::LogicalCardinalityExt;
47use crate::optimizer::property::{Distribution, RequiredDist};
48use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
49
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Logical join plan node.
///
/// Holds a [`generic::Join`] over two logical inputs together with the logical [`PlanBase`]
/// derived from it (see [`LogicalJoin::with_core`]).
pub struct LogicalJoin {
    /// Plan base derived from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// Inputs, join type, ON condition and output indices of the join.
    core: generic::Join<PlanRef>,
}
61
62impl Distill for LogicalJoin {
63 fn distill<'a>(&self) -> XmlNode<'a> {
64 let verbose = self.base.ctx().is_explain_verbose();
65 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
66 vec.push(("type", Pretty::debug(&self.join_type())));
67
68 let concat_schema = self.core.concat_schema();
69 let cond = Pretty::debug(&ConditionDisplay {
70 condition: self.on(),
71 input_schema: &concat_schema,
72 });
73 vec.push(("on", cond));
74
75 if verbose {
76 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
77 vec.push(("output", data));
78 }
79
80 childless_record("LogicalJoin", vec)
81 }
82}
83
84impl LogicalJoin {
85 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
86 let core = generic::Join::with_full_output(left, right, join_type, on);
87 Self::with_core(core)
88 }
89
90 pub(crate) fn with_output_indices(
91 left: PlanRef,
92 right: PlanRef,
93 join_type: JoinType,
94 on: Condition,
95 output_indices: Vec<usize>,
96 ) -> Self {
97 let core = generic::Join::new(left, right, on, join_type, output_indices);
98 Self::with_core(core)
99 }
100
101 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
102 let base = PlanBase::new_logical_with_core(&core);
103 LogicalJoin { base, core }
104 }
105
106 pub fn create(
107 left: PlanRef,
108 right: PlanRef,
109 join_type: JoinType,
110 on_clause: ExprImpl,
111 ) -> PlanRef {
112 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
113 }
114
115 pub fn internal_column_num(&self) -> usize {
116 self.core.internal_column_num()
117 }
118
119 pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
120 self.core.i2l_col_mapping_ignore_join_type()
121 }
122
123 pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
124 self.core.i2r_col_mapping_ignore_join_type()
125 }
126
127 pub fn on(&self) -> &Condition {
129 &self.core.on
130 }
131
132 pub fn core(&self) -> &generic::Join<PlanRef> {
133 &self.core
134 }
135
136 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
138 let input_refs = self
139 .core
140 .on
141 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
142 let index_group = input_refs
143 .ones()
144 .chunk_by(|i| *i < self.core.left.schema().len());
145 let left_index = index_group
146 .into_iter()
147 .next()
148 .map_or(vec![], |group| group.1.collect_vec());
149 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
150 group
151 .1
152 .map(|i| i - self.core.left.schema().len())
153 .collect_vec()
154 });
155 (left_index, right_index)
156 }
157
158 pub fn join_type(&self) -> JoinType {
160 self.core.join_type
161 }
162
163 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
165 self.core.eq_indexes()
166 }
167
168 pub fn output_indices(&self) -> &Vec<usize> {
170 &self.core.output_indices
171 }
172
173 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
175 Self::with_core(generic::Join {
176 output_indices,
177 ..self.core.clone()
178 })
179 }
180
181 pub fn clone_with_cond(&self, on: Condition) -> Self {
183 Self::with_core(generic::Join {
184 on,
185 ..self.core.clone()
186 })
187 }
188
189 pub fn is_left_join(&self) -> bool {
190 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
191 }
192
193 pub fn is_right_join(&self) -> bool {
194 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
195 }
196
197 pub fn is_full_out(&self) -> bool {
198 self.core.is_full_out()
199 }
200
201 pub fn is_asof_join(&self) -> bool {
202 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
203 }
204
205 pub fn output_indices_are_trivial(&self) -> bool {
206 self.output_indices() == &(0..self.internal_column_num()).collect_vec()
207 }
208
209 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
214 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
215 JoinType::LeftOuter => (false, true),
216 JoinType::RightOuter => (true, false),
217 JoinType::FullOuter => (true, true),
218 _ => return join_type,
219 };
220
221 for expr in &predicate.conjunctions {
222 if let ExprImpl::FunctionCall(func) = expr {
223 match func.func_type() {
224 ExprType::Equal
225 | ExprType::NotEqual
226 | ExprType::LessThan
227 | ExprType::LessThanOrEqual
228 | ExprType::GreaterThan
229 | ExprType::GreaterThanOrEqual => {
230 for input in func.inputs() {
231 if let ExprImpl::InputRef(input) = input {
232 let idx = input.index;
233 if idx < left_col_num {
234 gen_null_in_left = false;
235 } else {
236 gen_null_in_right = false;
237 }
238 }
239 }
240 }
241 _ => {}
242 };
243 }
244 }
245
246 match (gen_null_in_left, gen_null_in_right) {
247 (true, true) => JoinType::FullOuter,
248 (true, false) => JoinType::RightOuter,
249 (false, true) => JoinType::LeftOuter,
250 (false, false) => JoinType::Inner,
251 }
252 }
253
254 fn to_batch_lookup_join_with_index_selection(
258 &self,
259 predicate: EqJoinPredicate,
260 batch_join: generic::Join<BatchPlanRef>,
261 ) -> Result<Option<BatchLookupJoin>> {
262 match batch_join.join_type {
263 JoinType::Inner
264 | JoinType::LeftOuter
265 | JoinType::LeftSemi
266 | JoinType::LeftAnti
267 | JoinType::AsofInner
268 | JoinType::AsofLeftOuter => {}
269 _ => return Ok(None),
270 };
271
272 let right = self.right();
274 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
276 logical_scan
277 } else {
278 return Ok(None);
279 };
280
281 let mut result_plan = None;
282 if let Some(lookup_join) =
284 self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
285 {
286 result_plan = Some(lookup_join);
287 }
288
289 if self
290 .core
291 .ctx()
292 .session_ctx()
293 .config()
294 .enable_index_selection()
295 {
296 let indexes = logical_scan.table_indexes();
297 for index in indexes {
298 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
299 let index_scan: PlanRef = index_scan.into();
300 let that = self.clone_with_left_right(self.left(), index_scan.clone());
301 let mut new_batch_join = batch_join.clone();
302 new_batch_join.right =
303 index_scan.to_batch().expect("index scan failed to batch");
304
305 if let Some(lookup_join) =
307 that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
308 {
309 match &result_plan {
310 None => result_plan = Some(lookup_join),
311 Some(prev_lookup_join) => {
312 if prev_lookup_join.lookup_prefix_len()
314 < lookup_join.lookup_prefix_len()
315 {
316 result_plan = Some(lookup_join)
317 }
318 }
319 }
320 }
321 }
322 }
323 }
324
325 Ok(result_plan)
326 }
327
328 fn to_batch_lookup_join(
330 &self,
331 predicate: EqJoinPredicate,
332 logical_join: generic::Join<BatchPlanRef>,
333 ) -> Result<Option<BatchLookupJoin>> {
334 let logical_scan: &LogicalScan =
335 if let Some(logical_scan) = self.core.right.as_logical_scan() {
336 logical_scan
337 } else {
338 return Ok(None);
339 };
340 Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
341 }
342
343 pub fn gen_batch_lookup_join(
344 logical_scan: &LogicalScan,
345 predicate: EqJoinPredicate,
346 logical_join: generic::Join<BatchPlanRef>,
347 is_as_of: bool,
348 ) -> Result<Option<BatchLookupJoin>> {
349 match logical_join.join_type {
350 JoinType::Inner
351 | JoinType::LeftOuter
352 | JoinType::LeftSemi
353 | JoinType::LeftAnti
354 | JoinType::AsofInner
355 | JoinType::AsofLeftOuter => {}
356 _ => return Ok(None),
357 };
358
359 let table = logical_scan.table();
360 let output_column_ids = logical_scan.output_column_ids();
361
362 let order_col_ids = table.order_column_ids();
365 let dist_key = table.distribution_key.clone();
366 let mut dist_key_in_order_key_pos = vec![];
368 for d in dist_key {
369 let pos = table
370 .order_column_indices()
371 .position(|x| x == d)
372 .expect("dist_key must in order_key");
373 dist_key_in_order_key_pos.push(pos);
374 }
375 let shortest_prefix_len = dist_key_in_order_key_pos
377 .iter()
378 .max()
379 .map_or(0, |pos| pos + 1);
380
381 if shortest_prefix_len == 0 {
383 return Ok(None);
384 }
385
386 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
388 for order_col_id in order_col_ids {
389 let mut found = false;
390 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
391 if order_col_id == output_column_ids[eq_idx] {
392 reorder_idx.push(i);
393 found = true;
394 break;
395 }
396 }
397 if !found {
398 break;
399 }
400 }
401 if reorder_idx.len() < shortest_prefix_len {
402 return Ok(None);
403 }
404 let lookup_prefix_len = reorder_idx.len();
405 let predicate = predicate.reorder(&reorder_idx);
406
407 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
409 let o2r = if let Some(project_expr) = project_expr {
411 project_expr
412 .into_iter()
413 .map(|x| x.as_input_ref().unwrap().index)
414 .collect_vec()
415 } else {
416 (0..logical_scan.output_col_idx().len()).collect_vec()
417 };
418 let left_schema_len = logical_join.left.schema().len();
419
420 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
421 offset: left_schema_len,
422 mapping: o2r.clone(),
423 };
424
425 let new_eq_cond = predicate
426 .eq_cond()
427 .rewrite_expr(&mut join_predicate_rewriter);
428
429 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
430 offset: left_schema_len,
431 };
432
433 let new_other_cond = predicate
434 .other_cond()
435 .clone()
436 .rewrite_expr(&mut join_predicate_rewriter)
437 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
438
439 let new_join_on = new_eq_cond.and(new_other_cond);
440 let new_predicate = EqJoinPredicate::create(
441 left_schema_len,
442 new_scan.schema().len(),
443 new_join_on.clone(),
444 );
445
446 if !new_predicate.has_eq() {
449 return Ok(None);
450 }
451
452 let new_join_output_indices = logical_join
455 .output_indices
456 .iter()
457 .map(|&x| {
458 if x < left_schema_len {
459 x
460 } else {
461 o2r[x - left_schema_len] + left_schema_len
462 }
463 })
464 .collect_vec();
465
466 let new_scan_output_column_ids = new_scan.output_column_ids();
467 let as_of = new_scan.as_of.clone();
468 let new_logical_scan: LogicalScan = new_scan.into();
469
470 let new_logical_join = generic::Join::new(
472 logical_join.left,
473 new_logical_scan.to_batch()?,
474 new_join_on,
475 logical_join.join_type,
476 new_join_output_indices,
477 );
478
479 let asof_desc = is_as_of
480 .then(|| {
481 Self::get_inequality_desc_from_predicate(
482 predicate.other_cond().clone(),
483 left_schema_len,
484 )
485 })
486 .transpose()?;
487
488 Ok(Some(BatchLookupJoin::new(
489 new_logical_join,
490 new_predicate,
491 table.clone(),
492 new_scan_output_column_ids,
493 lookup_prefix_len,
494 false,
495 as_of,
496 asof_desc,
497 )))
498 }
499
500 pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
501 self.core.decompose()
502 }
503}
504
505impl PlanTreeNodeBinary<Logical> for LogicalJoin {
506 fn left(&self) -> PlanRef {
507 self.core.left.clone()
508 }
509
510 fn right(&self) -> PlanRef {
511 self.core.right.clone()
512 }
513
514 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
515 Self::with_core(generic::Join {
516 left,
517 right,
518 ..self.core.clone()
519 })
520 }
521
522 fn rewrite_with_left_right(
523 &self,
524 left: PlanRef,
525 left_col_change: ColIndexMapping,
526 right: PlanRef,
527 right_col_change: ColIndexMapping,
528 ) -> (Self, ColIndexMapping) {
529 let (new_on, new_output_indices) = {
530 let (mut map, _) = left_col_change.clone().into_parts();
531 let (mut right_map, _) = right_col_change.clone().into_parts();
532 for i in right_map.iter_mut().flatten() {
533 *i += left.schema().len();
534 }
535 map.append(&mut right_map);
536 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
537
538 let new_output_indices = self
539 .output_indices()
540 .iter()
541 .map(|&i| mapping.map(i))
542 .collect::<Vec<_>>();
543 let new_on = self.on().clone().rewrite_expr(&mut mapping);
544 (new_on, new_output_indices)
545 };
546
547 let join = Self::with_output_indices(
548 left,
549 right,
550 self.join_type(),
551 new_on,
552 new_output_indices.clone(),
553 );
554
555 let new_i2o = ColIndexMapping::with_remaining_columns(
556 &new_output_indices,
557 join.internal_column_num(),
558 );
559
560 let old_o2i = self.core.o2i_col_mapping();
561
562 let old_o2l = old_o2i
563 .composite(&self.core.i2l_col_mapping())
564 .composite(&left_col_change);
565 let old_o2r = old_o2i
566 .composite(&self.core.i2r_col_mapping())
567 .composite(&right_col_change);
568 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
569 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
570
571 let out_col_change = old_o2l
572 .composite(&new_l2o)
573 .union(&old_o2r.composite(&new_r2o));
574 (join, out_col_change)
575 }
576}
577
// Generic plan-tree-node plumbing for this binary node, generated by a macro defined
// elsewhere (presumably built on the `PlanTreeNodeBinary` impl of `LogicalJoin` — confirm).
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
579
580impl ColPrunable for LogicalJoin {
581 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
582 let required_cols = required_cols
584 .iter()
585 .map(|i| self.output_indices()[*i])
586 .collect_vec();
587 let left_len = self.left().schema().fields.len();
588
589 let total_len = self.left().schema().len() + self.right().schema().len();
590 let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
591
592 required_cols.iter().for_each(|&i| {
593 if self.is_right_join() {
594 resized_required_cols.insert(left_len + i);
595 } else {
596 resized_required_cols.insert(i);
597 }
598 });
599
600 let mut visitor = CollectInputRef::new(resized_required_cols);
603 self.on().visit_expr(&mut visitor);
604 let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
605
606 let mut left_required_cols = Vec::new();
607 let mut right_required_cols = Vec::new();
608 left_right_required_cols.iter().for_each(|&i| {
609 if i < left_len {
610 left_required_cols.push(i);
611 } else {
612 right_required_cols.push(i - left_len);
613 }
614 });
615
616 let mut on = self.on().clone();
617 let mut mapping =
618 ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
619 on = on.rewrite_expr(&mut mapping);
620
621 let new_output_indices = {
622 let required_inputs_in_output = if self.is_left_join() {
623 &left_required_cols
624 } else if self.is_right_join() {
625 &right_required_cols
626 } else {
627 &left_right_required_cols
628 };
629
630 let mapping =
631 ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
632 required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
633 };
634
635 LogicalJoin::with_output_indices(
636 self.left().prune_col(&left_required_cols, ctx),
637 self.right().prune_col(&right_required_cols, ctx),
638 self.join_type(),
639 on,
640 new_output_indices,
641 )
642 .into()
643 }
644}
645
646impl ExprRewritable<Logical> for LogicalJoin {
647 fn has_rewritable_expr(&self) -> bool {
648 true
649 }
650
651 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
652 let mut core = self.core.clone();
653 core.rewrite_exprs(r);
654 Self {
655 base: self.base.clone_with_new_plan_id(),
656 core,
657 }
658 .into()
659 }
660}
661
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions of the join core (delegates to `generic::Join::visit_exprs`).
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
667
668fn derive_predicate_from_eq_condition(
686 expr: &ExprImpl,
687 eq_condition: &EqJoinPredicate,
688 col_num: usize,
689 expr_is_left: bool,
690) -> Option<ExprImpl> {
691 if expr.is_impure() {
692 return None;
693 }
694 let eq_indices = eq_condition
695 .eq_indexes_typed()
696 .iter()
697 .filter_map(|(l, r)| {
698 if l.return_type() != r.return_type() {
699 None
700 } else if expr_is_left {
701 Some(l.index())
702 } else {
703 Some(r.index())
704 }
705 })
706 .collect_vec();
707 if expr
708 .collect_input_refs(col_num)
709 .ones()
710 .any(|index| !eq_indices.contains(&index))
711 {
712 return None;
714 }
715 let other_side_mapping = if expr_is_left {
718 eq_condition.eq_indexes_typed().into_iter().collect()
719 } else {
720 eq_condition
721 .eq_indexes_typed()
722 .into_iter()
723 .map(|(x, y)| (y, x))
724 .collect()
725 };
726 struct InputRefsRewriter {
727 mapping: HashMap<InputRef, InputRef>,
728 }
729 impl ExprRewriter for InputRefsRewriter {
730 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
731 self.mapping[&input_ref].clone().into()
732 }
733 }
734 Some(
735 InputRefsRewriter {
736 mapping: other_side_mapping,
737 }
738 .rewrite_expr(expr.clone()),
739 )
740}
741
/// Rewrites join-condition input refs after the right-side scan has been replaced by the scan
/// produced by `predicate_pull_up`.
///
/// Left-side refs (index < `offset`) are kept. Right-side refs are remapped through `mapping`
/// and shifted back by `offset`.
struct LookupJoinPredicateRewriter {
    /// Number of left-side columns (callers pass the left schema length).
    offset: usize,
    /// Original right-side column index -> column index of the new scan.
    mapping: Vec<usize>,
}
747impl ExprRewriter for LookupJoinPredicateRewriter {
748 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
749 if input_ref.index() < self.offset {
750 input_ref.into()
751 } else {
752 InputRef::new(
753 self.mapping[input_ref.index() - self.offset] + self.offset,
754 input_ref.return_type(),
755 )
756 .into()
757 }
758 }
759}
760
/// Shifts a predicate pulled up from the right-side scan (expressed over that scan's columns)
/// into the concatenated join column space by adding `offset` to every input ref.
struct LookupJoinScanPredicateRewriter {
    /// Number of left-side columns (callers pass the left schema length).
    offset: usize,
}
765impl ExprRewriter for LookupJoinScanPredicateRewriter {
766 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
767 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
768 }
769}
770
771impl PredicatePushdown for LogicalJoin {
772 fn predicate_pushdown(
796 &self,
797 predicate: Condition,
798 ctx: &mut PredicatePushdownContext,
799 ) -> PlanRef {
800 let mut predicate = {
802 let mut mapping = self.core.o2i_col_mapping();
803 predicate.rewrite_expr(&mut mapping)
804 };
805
806 let left_col_num = self.left().schema().len();
807 let right_col_num = self.right().schema().len();
808 let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
809
810 let push_down_temporal_predicate = !self.should_be_temporal_join();
811
812 let (left_from_filter, right_from_filter, on) = push_down_into_join(
813 &mut predicate,
814 left_col_num,
815 right_col_num,
816 join_type,
817 push_down_temporal_predicate,
818 );
819
820 let mut new_on = self.on().clone().and(on);
821 let (left_from_on, right_from_on) = push_down_join_condition(
822 &mut new_on,
823 left_col_num,
824 right_col_num,
825 join_type,
826 push_down_temporal_predicate,
827 );
828
829 let left_predicate = left_from_filter.and(left_from_on);
830 let right_predicate = right_from_filter.and(right_from_on);
831
832 let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
834
835 let right_from_left = if matches!(
837 join_type,
838 JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
839 ) {
840 Condition {
841 conjunctions: left_predicate
842 .conjunctions
843 .iter()
844 .filter_map(|expr| {
845 derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
846 })
847 .collect(),
848 }
849 } else {
850 Condition::true_cond()
851 };
852
853 let left_from_right = if matches!(
855 join_type,
856 JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
857 ) {
858 Condition {
859 conjunctions: right_predicate
860 .conjunctions
861 .iter()
862 .filter_map(|expr| {
863 derive_predicate_from_eq_condition(
864 expr,
865 &eq_condition,
866 right_col_num,
867 false,
868 )
869 })
870 .collect(),
871 }
872 } else {
873 Condition::true_cond()
874 };
875
876 let left_predicate = left_predicate.and(left_from_right);
877 let right_predicate = right_predicate.and(right_from_left);
878
879 let new_left = self.left().predicate_pushdown(left_predicate, ctx);
880 let new_right = self.right().predicate_pushdown(right_predicate, ctx);
881 let new_join = LogicalJoin::with_output_indices(
882 new_left,
883 new_right,
884 join_type,
885 new_on,
886 self.output_indices().clone(),
887 );
888
889 let mut mapping = self.core.i2o_col_mapping();
890 predicate = predicate.rewrite_expr(&mut mapping);
891 LogicalFilter::create(new_join.into(), predicate)
892 }
893}
894
895impl LogicalJoin {
    /// Converts both inputs to streaming plans with distributions compatible with a hash join
    /// on the equi-keys of `predicate`.
    ///
    /// The right side is converted first, requiring a shard by its equi keys.
    /// - If the right side ends up `HashShard`, the left side is required to follow that exact
    ///   distribution, translated through the equi-key mapping.
    /// - If the right side is `UpstreamHashShard`, the left side is sharded by its own equi
    ///   keys instead:
    ///   - if the left side is `HashShard`, the right side is re-enforced to match it;
    ///   - if both sides are `UpstreamHashShard`, both are enforced to hash-shard by their
    ///     equi keys.
    ///
    /// Any other distribution hits `unreachable!()`.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        let lhs_join_key_idx = self.eq_indexes().into_iter().map(|(l, _)| l).collect_vec();
        let rhs_join_key_idx = self.eq_indexes().into_iter().map(|(_, r)| r).collect_vec();

        let logical_right = try_enforce_locality_requirement(self.right(), &rhs_join_key_idx);
        let mut right = logical_right.to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
        // Translate required distributions between the two sides via the equi-key pairs.
        let r2l =
            predicate.r2l_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(logical_left.schema().len(), right.schema().len());
        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // Make the left side follow the right side's exact hash distribution.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = logical_left.to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                left = logical_left.to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Re-enforce the right side to follow the left side's distribution.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Both sides are upstream-sharded: enforce a hash shard on each side.
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
953
954 fn to_stream_hash_join(
955 &self,
956 predicate: EqJoinPredicate,
957 ctx: &mut ToStreamContext,
958 ) -> Result<StreamPlanRef> {
959 use super::stream::prelude::*;
960
961 assert!(predicate.has_eq());
962 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
963
964 let core = self.core.clone_with_inputs(left, right);
965
966 let stream_hash_join = StreamHashJoin::new(core.clone(), predicate.clone())?;
975
976 let force_filter_inside_join = self
977 .base
978 .ctx()
979 .session_ctx()
980 .config()
981 .streaming_force_filter_inside_join();
982
983 let pull_filter = self.join_type() == JoinType::Inner
984 && stream_hash_join.eq_join_predicate().has_non_eq()
985 && stream_hash_join.inequality_pairs().is_empty()
986 && (!force_filter_inside_join);
987 if pull_filter {
988 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
989
990 let mut core = core;
991 core.output_indices = default_indices.clone();
992 let eq_cond = EqJoinPredicate::new(
994 Condition::true_cond(),
995 predicate.eq_keys().to_vec(),
996 self.left().schema().len(),
997 self.right().schema().len(),
998 );
999 core.on = eq_cond.eq_cond();
1000 let hash_join = StreamHashJoin::new(core, eq_cond)?.into();
1001 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1002 let plan = StreamFilter::new(logical_filter).into();
1003 if self.output_indices() != &default_indices {
1004 let logical_project = generic::Project::with_mapping(
1005 plan,
1006 ColIndexMapping::with_remaining_columns(
1007 self.output_indices(),
1008 self.internal_column_num(),
1009 ),
1010 );
1011 Ok(StreamProject::new(logical_project).into())
1012 } else {
1013 Ok(plan)
1014 }
1015 } else {
1016 Ok(stream_hash_join.into())
1017 }
1018 }
1019
1020 fn should_be_temporal_join(&self) -> bool {
1021 let right = self.right();
1022 if let Some(logical_scan) = right.as_logical_scan() {
1023 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1024 } else {
1025 false
1026 }
1027 }
1028
1029 fn to_stream_temporal_join_with_index_selection(
1030 &self,
1031 predicate: EqJoinPredicate,
1032 ctx: &mut ToStreamContext,
1033 ) -> Result<StreamPlanRef> {
1034 let right = self.right();
1036 let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1038
1039 let mut result_plan: Result<StreamTemporalJoin> =
1041 self.to_stream_temporal_join(predicate.clone(), ctx);
1042 if let Ok(temporal_join) = &result_plan
1044 && temporal_join.eq_join_predicate().eq_indexes().len()
1045 == logical_scan.primary_key().len()
1046 {
1047 return result_plan.map(|x| x.into());
1048 }
1049 if self
1050 .core
1051 .ctx()
1052 .session_ctx()
1053 .config()
1054 .enable_index_selection()
1055 {
1056 let indexes = logical_scan.table_indexes();
1057 for index in indexes {
1058 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1060 let index_scan: PlanRef = index_scan.into();
1061 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1062 if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx)
1063 {
1064 match &result_plan {
1065 Err(_) => result_plan = Ok(temporal_join),
1066 Ok(prev_temporal_join) => {
1067 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1069 < temporal_join.eq_join_predicate().eq_indexes().len()
1070 {
1071 result_plan = Ok(temporal_join)
1072 }
1073 }
1074 }
1075 }
1076 }
1077 }
1078 }
1079
1080 result_plan.map(|x| x.into())
1081 }
1082
1083 fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1084 let Some(logical_scan) = right.as_logical_scan() else {
1085 return Err(RwError::from(ErrorCode::NotSupported(
1086 "Temporal join requires a table scan as its lookup table".into(),
1087 "Please provide a table scan".into(),
1088 )));
1089 };
1090
1091 if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1092 return Err(RwError::from(ErrorCode::NotSupported(
1093 "Temporal join requires a table defined as temporal table".into(),
1094 "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1095 )));
1096 }
1097 Ok(logical_scan)
1098 }
1099
1100 fn temporal_join_scan_predicate_pull_up(
1101 logical_scan: &LogicalScan,
1102 predicate: EqJoinPredicate,
1103 output_indices: &[usize],
1104 left_schema_len: usize,
1105 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1106 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1108 let o2r = if let Some(project_expr) = project_expr {
1110 project_expr
1111 .into_iter()
1112 .map(|x| x.as_input_ref().unwrap().index)
1113 .collect_vec()
1114 } else {
1115 (0..logical_scan.output_col_idx().len()).collect_vec()
1116 };
1117 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1118 offset: left_schema_len,
1119 mapping: o2r.clone(),
1120 };
1121
1122 let new_eq_cond = predicate
1123 .eq_cond()
1124 .rewrite_expr(&mut join_predicate_rewriter);
1125
1126 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1127 offset: left_schema_len,
1128 };
1129
1130 let new_other_cond = predicate
1131 .other_cond()
1132 .clone()
1133 .rewrite_expr(&mut join_predicate_rewriter)
1134 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1135
1136 let new_join_on = new_eq_cond.and(new_other_cond);
1137
1138 let new_predicate = EqJoinPredicate::create(
1139 left_schema_len,
1140 new_scan.schema().len(),
1141 new_join_on.clone(),
1142 );
1143
1144 let new_join_output_indices = output_indices
1147 .iter()
1148 .map(|&x| {
1149 if x < left_schema_len {
1150 x
1151 } else {
1152 o2r[x - left_schema_len] + left_schema_len
1153 }
1154 })
1155 .collect_vec();
1156
1157 if new_scan.cross_database() {
1159 return Err(RwError::from(ErrorCode::NotSupported(
1160 "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1161 "Please ensure both tables are in the same database".into(),
1162 )));
1163 }
1164 let new_stream_table_scan =
1165 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1166 Ok((
1167 new_stream_table_scan,
1168 new_predicate,
1169 new_join_on,
1170 new_join_output_indices,
1171 ))
1172 }
1173
1174 fn to_stream_temporal_join(
1175 &self,
1176 predicate: EqJoinPredicate,
1177 ctx: &mut ToStreamContext,
1178 ) -> Result<StreamTemporalJoin> {
1179 use super::stream::prelude::*;
1180
1181 assert!(predicate.has_eq());
1182
1183 let right = self.right();
1184
1185 let logical_scan = Self::check_temporal_rhs(&right)?;
1186
1187 let table = logical_scan.table();
1188 let output_column_ids = logical_scan.output_column_ids();
1189
1190 let order_col_ids = table.order_column_ids();
1193 let dist_key = table.distribution_key.clone();
1194
1195 let mut dist_key_in_order_key_pos = vec![];
1196 for d in dist_key {
1197 let pos = table
1198 .order_column_indices()
1199 .position(|x| x == d)
1200 .expect("dist_key must in order_key");
1201 dist_key_in_order_key_pos.push(pos);
1202 }
1203 let shortest_prefix_len = dist_key_in_order_key_pos
1205 .iter()
1206 .max()
1207 .map_or(0, |pos| pos + 1);
1208
1209 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1211 for order_col_id in order_col_ids {
1212 let mut found = false;
1213 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1214 if order_col_id == output_column_ids[eq_idx] {
1215 reorder_idx.push(i);
1216 found = true;
1217 break;
1218 }
1219 }
1220 if !found {
1221 break;
1222 }
1223 }
1224 if reorder_idx.len() < shortest_prefix_len {
1225 return Err(RwError::from(ErrorCode::NotSupported(
1227 "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
1228 "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
1229 )));
1230 }
1231 let lookup_prefix_len = reorder_idx.len();
1232 let predicate = predicate.reorder(&reorder_idx);
1233
1234 let required_dist = if dist_key_in_order_key_pos.is_empty() {
1235 RequiredDist::single()
1236 } else {
1237 let left_eq_indexes = predicate.left_eq_indexes();
1238 let left_dist_key = dist_key_in_order_key_pos
1239 .iter()
1240 .map(|pos| left_eq_indexes[*pos])
1241 .collect_vec();
1242
1243 RequiredDist::hash_shard(&left_dist_key)
1244 };
1245
1246 let lhs_join_key_idx = predicate
1247 .eq_indexes()
1248 .into_iter()
1249 .map(|(l, _)| l)
1250 .collect_vec();
1251 let logical_left = try_enforce_locality_requirement(self.left(), &lhs_join_key_idx);
1252 let left = logical_left.to_stream(ctx)?;
1253 let left = required_dist.stream_enforce(left);
1255
1256 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1257 Self::temporal_join_scan_predicate_pull_up(
1258 logical_scan,
1259 predicate,
1260 self.output_indices(),
1261 self.left().schema().len(),
1262 )?;
1263
1264 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1265 if !new_predicate.has_eq() {
1266 return Err(RwError::from(ErrorCode::NotSupported(
1267 "Temporal join requires a non trivial join condition".into(),
1268 "Please remove the false condition of the join".into(),
1269 )));
1270 }
1271
1272 let new_logical_join = generic::Join::new(
1274 left,
1275 right,
1276 new_join_on,
1277 self.join_type(),
1278 new_join_output_indices,
1279 );
1280
1281 let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1282
1283 StreamTemporalJoin::new(new_logical_join, new_predicate, false)
1284 }
1285
1286 fn to_stream_nested_loop_temporal_join(
1287 &self,
1288 predicate: EqJoinPredicate,
1289 ctx: &mut ToStreamContext,
1290 ) -> Result<StreamPlanRef> {
1291 use super::stream::prelude::*;
1292 assert!(!predicate.has_eq());
1293
1294 let left = self.left().to_stream_with_dist_required(
1295 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1296 ctx,
1297 )?;
1298 assert!(left.as_stream_exchange().is_some());
1299
1300 if self.join_type() != JoinType::Inner {
1301 return Err(RwError::from(ErrorCode::NotSupported(
1302 "Temporal join requires an inner join".into(),
1303 "Please use an inner join".into(),
1304 )));
1305 }
1306
1307 if !left.append_only() {
1308 return Err(RwError::from(ErrorCode::NotSupported(
1309 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1310 "Please ensure the left hash side is append only".into(),
1311 )));
1312 }
1313
1314 let right = self.right();
1315 let logical_scan = Self::check_temporal_rhs(&right)?;
1316
1317 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1318 Self::temporal_join_scan_predicate_pull_up(
1319 logical_scan,
1320 predicate,
1321 self.output_indices(),
1322 self.left().schema().len(),
1323 )?;
1324
1325 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1326
1327 let new_logical_join = generic::Join::new(
1329 left,
1330 right,
1331 new_join_on,
1332 self.join_type(),
1333 new_join_output_indices,
1334 );
1335
1336 Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true)?.into())
1337 }
1338
1339 fn to_stream_dynamic_filter(
1340 &self,
1341 predicate: Condition,
1342 ctx: &mut ToStreamContext,
1343 ) -> Result<Option<StreamPlanRef>> {
1344 use super::stream::prelude::*;
1345
1346 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1352 return Ok(None);
1353 }
1354
1355 if !self.right().max_one_row() {
1357 return Ok(None);
1358 }
1359 if self.right().schema().len() != 1 {
1360 return Ok(None);
1361 }
1362
1363 if predicate.conjunctions.len() > 1 {
1365 return Ok(None);
1366 }
1367 let expr: ExprImpl = predicate.into();
1368 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1369 Some(v) => v,
1370 None => return Ok(None),
1371 };
1372
1373 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1374 && right_ref.index == self.left().schema().len() ;
1375 if !condition_cross_inputs {
1376 return Ok(None);
1378 }
1379
1380 if self.left().schema().fields()[left_ref.index].data_type
1382 != self.right().schema().fields()[0].data_type
1383 {
1384 return Ok(None);
1385 }
1386
1387 let all_output_from_left = self
1389 .output_indices()
1390 .iter()
1391 .all(|i| *i < self.left().schema().len());
1392 if !all_output_from_left {
1393 return Ok(None);
1394 }
1395
1396 let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1397 let right = self.right().to_stream_with_dist_required(
1398 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1399 ctx,
1400 )?;
1401
1402 assert!(right.as_stream_exchange().is_some());
1403 assert_eq!(
1404 *right.inputs().iter().exactly_one().unwrap().distribution(),
1405 Distribution::Single
1406 );
1407
1408 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1409 let plan = StreamDynamicFilter::new(core)?.into();
1410 if self
1412 .output_indices()
1413 .iter()
1414 .copied()
1415 .ne(0..self.left().schema().len())
1416 {
1417 let logical_project = generic::Project::with_mapping(
1420 plan,
1421 ColIndexMapping::with_remaining_columns(
1422 self.output_indices(),
1423 self.left().schema().len(),
1424 ),
1425 );
1426 Ok(Some(StreamProject::new(logical_project).into()))
1427 } else {
1428 Ok(Some(plan))
1429 }
1430 }
1431
1432 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1433 let predicate = EqJoinPredicate::create(
1434 self.left().schema().len(),
1435 self.right().schema().len(),
1436 self.on().clone(),
1437 );
1438 assert!(predicate.has_eq());
1439
1440 let join = self
1441 .core
1442 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1443
1444 Ok(self
1445 .to_batch_lookup_join(predicate, join)?
1446 .expect("Fail to convert to lookup join")
1447 .into())
1448 }
1449
1450 fn to_stream_asof_join(
1451 &self,
1452 predicate: EqJoinPredicate,
1453 ctx: &mut ToStreamContext,
1454 ) -> Result<StreamPlanRef> {
1455 use super::stream::prelude::*;
1456
1457 if predicate.eq_keys().is_empty() {
1458 return Err(ErrorCode::InvalidInputSyntax(
1459 "AsOf join requires at least 1 equal condition".to_owned(),
1460 )
1461 .into());
1462 }
1463
1464 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1465 let left_len = left.schema().len();
1466 let core = self.core.clone_with_inputs(left, right);
1467
1468 let inequality_desc =
1469 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1470
1471 Ok(StreamAsOfJoin::new(core, predicate, inequality_desc)?.into())
1472 }
1473
1474 fn to_batch_hash_join(
1476 &self,
1477 logical_join: generic::Join<BatchPlanRef>,
1478 predicate: EqJoinPredicate,
1479 ) -> Result<BatchPlanRef> {
1480 use super::batch::prelude::*;
1481
1482 let left_schema_len = logical_join.left.schema().len();
1483 let asof_desc = self
1484 .is_asof_join()
1485 .then(|| {
1486 Self::get_inequality_desc_from_predicate(
1487 predicate.other_cond().clone(),
1488 left_schema_len,
1489 )
1490 })
1491 .transpose()?;
1492
1493 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1494 Ok(batch_join.into())
1495 }
1496
1497 pub fn get_inequality_desc_from_predicate(
1498 predicate: Condition,
1499 left_input_len: usize,
1500 ) -> Result<AsOfJoinDesc> {
1501 let expr: ExprImpl = predicate.into();
1502 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1503 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1504 {
1505 Ok(AsOfJoinDesc {
1506 left_idx: left_input_ref.index() as u32,
1507 right_idx: (right_input_ref.index() - left_input_len) as u32,
1508 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1509 })
1510 } else {
1511 bail!("inequal condition from the same side should be push down in optimizer");
1512 }
1513 } else {
1514 Err(ErrorCode::InvalidInputSyntax(
1515 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1516 )
1517 .into())
1518 }
1519 }
1520
1521 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1522 match expr_type {
1523 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1524 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1525 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1526 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1527 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1528 "Invalid comparison type: {}",
1529 expr_type.as_str_name()
1530 ))
1531 .into()),
1532 }
1533 }
1534}
1535
1536impl ToBatch for LogicalJoin {
1537 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1538 let predicate = EqJoinPredicate::create(
1539 self.left().schema().len(),
1540 self.right().schema().len(),
1541 self.on().clone(),
1542 );
1543
1544 let batch_join = self
1545 .core
1546 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1547
1548 let ctx = self.base.ctx();
1549 let config = ctx.session_ctx().config();
1550
1551 if predicate.has_eq() {
1552 if !predicate.eq_keys_are_type_aligned() {
1553 return Err(ErrorCode::InternalError(format!(
1554 "Join eq keys are not aligned for predicate: {predicate:?}"
1555 ))
1556 .into());
1557 }
1558 if config.batch_enable_lookup_join()
1559 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1560 predicate.clone(),
1561 batch_join.clone(),
1562 )?
1563 {
1564 return Ok(lookup_join.into());
1565 }
1566 self.to_batch_hash_join(batch_join, predicate)
1567 } else if self.is_asof_join() {
1568 Err(ErrorCode::InvalidInputSyntax(
1569 "AsOf join requires at least 1 equal condition".to_owned(),
1570 )
1571 .into())
1572 } else {
1573 Ok(BatchNestedLoopJoin::new(batch_join).into())
1575 }
1576 }
1577}
1578
1579impl ToStream for LogicalJoin {
1580 fn to_stream(
1581 &self,
1582 ctx: &mut ToStreamContext,
1583 ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1584 if self
1585 .on()
1586 .conjunctions
1587 .iter()
1588 .any(|cond| cond.count_nows() > 0)
1589 {
1590 return Err(ErrorCode::NotSupported(
1591 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1592 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1593 }
1594
1595 let predicate = EqJoinPredicate::create(
1596 self.left().schema().len(),
1597 self.right().schema().len(),
1598 self.on().clone(),
1599 );
1600
1601 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1602 self.to_stream_asof_join(predicate, ctx)
1603 } else if predicate.has_eq() {
1604 if !predicate.eq_keys_are_type_aligned() {
1605 return Err(ErrorCode::InternalError(format!(
1606 "Join eq keys are not aligned for predicate: {predicate:?}"
1607 ))
1608 .into());
1609 }
1610
1611 if self.should_be_temporal_join() {
1612 self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1613 } else {
1614 self.to_stream_hash_join(predicate, ctx)
1615 }
1616 } else if self.should_be_temporal_join() {
1617 self.to_stream_nested_loop_temporal_join(predicate, ctx)
1618 } else if let Some(dynamic_filter) =
1619 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1620 {
1621 Ok(dynamic_filter)
1622 } else {
1623 Err(RwError::from(ErrorCode::NotSupported(
1624 "streaming nested-loop join".to_owned(),
1625 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1626 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1627 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1628 )))
1629 }
1630 }
1631
1632 fn logical_rewrite_for_stream(
1633 &self,
1634 ctx: &mut RewriteStreamContext,
1635 ) -> Result<(PlanRef, ColIndexMapping)> {
1636 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1637 let left_len = left.schema().len();
1638 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1639 let (join, out_col_change) = self.rewrite_with_left_right(
1640 left.clone(),
1641 left_col_change,
1642 right.clone(),
1643 right_col_change,
1644 );
1645
1646 let mapping = ColIndexMapping::with_remaining_columns(
1647 join.output_indices(),
1648 join.internal_column_num(),
1649 );
1650
1651 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1652 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1653
1654 let mut left_to_add = left
1656 .expect_stream_key()
1657 .iter()
1658 .cloned()
1659 .filter(|i| l2o.try_map(*i).is_none())
1660 .collect_vec();
1661
1662 let mut right_to_add = right
1663 .expect_stream_key()
1664 .iter()
1665 .filter(|&&i| r2o.try_map(i).is_none())
1666 .map(|&i| i + left_len)
1667 .collect_vec();
1668
1669 let right_len = right.schema().len();
1672 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1673
1674 let either_or_both = self.core.add_which_join_key_to_pk();
1675
1676 for (lk, rk) in eq_predicate.eq_indexes() {
1677 match either_or_both {
1678 EitherOrBoth::Left(_) => {
1679 if l2o.try_map(lk).is_none() {
1680 left_to_add.push(lk);
1681 }
1682 }
1683 EitherOrBoth::Right(_) => {
1684 if r2o.try_map(rk).is_none() {
1685 right_to_add.push(rk + left_len)
1686 }
1687 }
1688 EitherOrBoth::Both(_, _) => {
1689 if l2o.try_map(lk).is_none() {
1690 left_to_add.push(lk);
1691 }
1692 if r2o.try_map(rk).is_none() {
1693 right_to_add.push(rk + left_len)
1694 }
1695 }
1696 };
1697 }
1698 let left_to_add = left_to_add.into_iter().unique();
1699 let right_to_add = right_to_add.into_iter().unique();
1700 let mut new_output_indices = join.output_indices().clone();
1703 if !join.is_right_join() {
1704 new_output_indices.extend(left_to_add);
1705 }
1706 if !join.is_left_join() {
1707 new_output_indices.extend(right_to_add);
1708 }
1709
1710 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1711
1712 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1713 let l2o = join_with_pk
1716 .core
1717 .l2i_col_mapping()
1718 .composite(&join_with_pk.core.i2o_col_mapping());
1719 let r2o = join_with_pk
1720 .core
1721 .r2i_col_mapping()
1722 .composite(&join_with_pk.core.i2o_col_mapping());
1723 let left_right_stream_keys = join_with_pk
1724 .left()
1725 .expect_stream_key()
1726 .iter()
1727 .map(|i| l2o.map(*i))
1728 .chain(
1729 join_with_pk
1730 .right()
1731 .expect_stream_key()
1732 .iter()
1733 .map(|i| r2o.map(*i)),
1734 )
1735 .collect_vec();
1736 let plan: PlanRef = join_with_pk.into();
1737 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1738 } else {
1739 join_with_pk.into()
1740 };
1741
1742 Ok((plan, out_col_change))
1744 }
1745
1746 fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1747 let mut ctx = ToStreamContext::new(false);
1748 if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) {
1750 let o2i_mapping = self.core.o2i_col_mapping();
1752 let left_input_columns = columns
1753 .iter()
1754 .map(|&col| o2i_mapping.try_map(col))
1755 .collect::<Option<Vec<usize>>>()?;
1756 if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1757 return Some(
1758 self.clone_with_left_right(better_left_plan, self.right())
1759 .into(),
1760 );
1761 }
1762 }
1763 None
1764 }
1765}
1766
// Unit tests for `LogicalJoin`: column pruning, batch conversion and
// functional-dependency derivation.
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Prunes an inner join on `v2 = v4` (left `v1..v3`, right `v4..v6`) to the output
    /// columns `[2, 3]` (`v3`, `v4`).
    ///
    /// Expected results:
    /// - The output schema is `[v3, v4]`.
    /// - The ON condition is remapped to `$0 = $2`.
    /// - The left input is pruned to `[v2, v3]`; `v2` is kept because the condition uses it.
    /// - The right input is pruned to `[v4]`.
    #[tokio::test]
    async fn test_prune_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Covers `LeftSemi`, `RightSemi`, `LeftAnti` and `RightAnti` joins on `v2 = v5`.
    ///
    /// Pruning to `[0]` and to `[0, 1, 2]` keeps the requested columns of the output side:
    /// the left input for left joins, and the right input for right joins (offset 3
    /// into `fields`).
    #[tokio::test]
    async fn test_prune_semi_join() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
                ],
            )
            .unwrap(),
        ));
        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Right semi/anti joins output the right input's columns (`fields[3..6]`).
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();
            let required_cols = vec![0];
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            let required_cols = vec![0, 1, 2];
            let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 3);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
        }
    }

    /// Prunes the inner join on `v2 = v4` to `[1, 3]` (`v2`, `v4`), which are exactly the
    /// columns the ON condition uses.
    ///
    /// Expects the condition to be remapped to `$0 = $1`, and each input to be pruned to
    /// its single key column.
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Converts an inner join on `($1 = $3) AND ($2 = 42)` to batch.
    ///
    /// Expects a `BatchHashJoin` whose equi-condition is `$1 = $3` and whose first
    /// non-equi conjunction is `$2 = 42`.
    #[tokio::test]
    async fn test_join_to_batch() {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );

        fn input_ref(i: usize) -> ExprImpl {
            ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
        }
        let eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
        ));
        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    input_ref(2),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Datum::Some(42_i32.into()),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let join_type = JoinType::Inner;
        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on_cond),
        );

        let result = logical_join.to_batch().unwrap();

        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Ignored placeholder with an empty body; it asserts nothing.
    /// NOTE(review): either restore a real body or remove this test.
    #[tokio::test]
    #[ignore]
    async fn test_join_to_stream() {}

    /// Same inner join as `test_prune_join`, but the columns are requested out of order
    /// (`[3, 2]`).
    ///
    /// Expects the output schema to follow the requested order (`[v4, v3]`). The inputs and
    /// the ON condition are pruned as in `test_prune_join`.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
                    ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
                ],
            )
            .unwrap(),
        ));
        let join_type = JoinType::Inner;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            join_type,
            Condition::with_expr(on),
        )
        .into();

        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Checks the functional dependencies derived for each join type.
    ///
    /// Setup:
    /// - Left input `(l0, l1)` carries `l0 -> l1`.
    /// - Right input `(r0, r1, r2)` carries `r0 -> (r1, r2)`.
    /// - The ON condition is `l0 = 0 AND l1 = r1` (`$0 = 0 AND $1 = $3`).
    ///
    /// Per the expected sets:
    /// - Inner joins keep both inputs' dependencies and add the ones implied by the
    ///   condition.
    /// - Outer joins keep only the dependencies of the preserved side(s); a full outer
    ///   join keeps none.
    /// - Semi and anti joins keep the dependencies of the side they output.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock().await;
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };
        let on: ExprImpl = FunctionCall::new(
            Type::And,
            vec![
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(0, DataType::Int32).into(),
                        ExprImpl::literal_int(0),
                    ],
                )
                .unwrap()
                .into(),
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(1, DataType::Int32).into(),
                        InputRef::new(3, DataType::Int32).into(),
                    ],
                )
                .unwrap()
                .into(),
            ],
        )
        .unwrap()
        .into();
        let expected_fd_set = [
            (
                JoinType::Inner,
                [
                    // l0 -> l1, from the left input
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    // r0 -> (r1, r2), from the right input
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    // `l0 = 0` makes l0 constant
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    // `l1 = r1`, in both directions
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                [
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftOuter,
                [
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftSemi,
                [
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftAnti,
                [
                    FunctionalDependency::with_indices(2, &[0], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightSemi,
                [
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightAnti,
                [
                    FunctionalDependency::with_indices(3, &[0], &[1, 2]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}