1use std::collections::HashMap;
16use std::ops::Deref;
17
18use fixedbitset::FixedBitSet;
19use itertools::{EitherOrBoth, Itertools};
20use pretty_xmlish::{Pretty, XmlNode};
21use risingwave_expr::bail;
22use risingwave_pb::expr::expr_node::PbType;
23use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
24use risingwave_pb::stream_plan::StreamScanType;
25use risingwave_sqlparser::ast::AsOf;
26
27use super::generic::{
28 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
29};
30use super::utils::{Distill, childless_record};
31use super::{
32 BackfillType, BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
33 PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject,
34 ToBatch, ToStream, generic, try_enforce_locality_requirement,
35};
36use crate::error::{ErrorCode, Result, RwError};
37use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
38use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
39use crate::optimizer::plan_node::generic::DynamicFilter;
40use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
41use crate::optimizer::plan_node::utils::IndicesDisplay;
42use crate::optimizer::plan_node::{
43 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
44 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
45 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
46};
47use crate::optimizer::plan_visitor::LogicalCardinalityExt;
48use crate::optimizer::property::{Distribution, RequiredDist};
49use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
50
/// Logical join of two input plans.
///
/// All join state lives in `core`: the `left`/`right` inputs, the `join_type`, the `on`
/// predicate (a logical join stores it as a `Condition`, see [`LogicalJoin::on`]) and the
/// `output_indices` selecting which internal (left ++ right) columns the join emits.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    pub base: PlanBase<Logical>,
    core: generic::Join<PlanRef>,
}
62
63impl Distill for LogicalJoin {
64 fn distill<'a>(&self) -> XmlNode<'a> {
65 let verbose = self.base.ctx().is_explain_verbose();
66 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
67 vec.push(("type", Pretty::debug(&self.join_type())));
68
69 let concat_schema = self.core.concat_schema();
70 let cond = Pretty::debug(&ConditionDisplay {
71 condition: self.on(),
72 input_schema: &concat_schema,
73 });
74 vec.push(("on", cond));
75
76 if verbose {
77 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
78 vec.push(("output", data));
79 }
80
81 childless_record("LogicalJoin", vec)
82 }
83}
84
85impl LogicalJoin {
86 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
87 let core = generic::Join::with_full_output(left, right, join_type, on);
88 Self::with_core(core)
89 }
90
91 pub(crate) fn with_output_indices(
92 left: PlanRef,
93 right: PlanRef,
94 join_type: JoinType,
95 on: Condition,
96 output_indices: Vec<usize>,
97 ) -> Self {
98 let core = generic::Join::new(left, right, on, join_type, output_indices);
99 Self::with_core(core)
100 }
101
102 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
103 let base = PlanBase::new_logical_with_core(&core);
104 LogicalJoin { base, core }
105 }
106
107 pub fn create(
108 left: PlanRef,
109 right: PlanRef,
110 join_type: JoinType,
111 on_clause: ExprImpl,
112 ) -> PlanRef {
113 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
114 }
115
116 pub fn internal_column_num(&self) -> usize {
117 self.core.internal_column_num()
118 }
119
120 pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
121 self.core.i2l_col_mapping_ignore_join_type()
122 }
123
124 pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
125 self.core.i2r_col_mapping_ignore_join_type()
126 }
127
128 pub fn on(&self) -> &Condition {
130 self.core
131 .on
132 .as_condition_ref()
133 .expect("logical join should store predicate as Condition")
134 }
135
136 pub fn core(&self) -> &generic::Join<PlanRef> {
137 &self.core
138 }
139
140 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
142 let input_refs = self
143 .core
144 .on
145 .as_condition_ref()
146 .expect("logical join should store predicate as Condition")
147 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
148 let index_group = input_refs
149 .ones()
150 .chunk_by(|i| *i < self.core.left.schema().len());
151 let left_index = index_group
152 .into_iter()
153 .next()
154 .map_or(vec![], |group| group.1.collect_vec());
155 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
156 group
157 .1
158 .map(|i| i - self.core.left.schema().len())
159 .collect_vec()
160 });
161 (left_index, right_index)
162 }
163
164 pub fn join_type(&self) -> JoinType {
166 self.core.join_type
167 }
168
169 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
171 self.core.eq_indexes()
172 }
173
174 pub fn output_indices(&self) -> &Vec<usize> {
176 &self.core.output_indices
177 }
178
179 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
181 Self::with_core(generic::Join {
182 output_indices,
183 ..self.core.clone()
184 })
185 }
186
187 pub fn clone_with_cond(&self, on: Condition) -> Self {
189 Self::with_core(generic::Join {
190 on: generic::JoinOn::Condition(on),
191 ..self.core.clone()
192 })
193 }
194
195 pub fn is_left_join(&self) -> bool {
196 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
197 }
198
199 pub fn is_right_join(&self) -> bool {
200 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
201 }
202
203 pub fn is_full_out(&self) -> bool {
204 self.core.is_full_out()
205 }
206
207 pub fn is_asof_join(&self) -> bool {
208 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
209 }
210
211 pub fn output_indices_are_trivial(&self) -> bool {
212 itertools::equal(
213 self.output_indices().iter().cloned(),
214 0..self.internal_column_num(),
215 )
216 }
217
218 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
223 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
224 JoinType::LeftOuter => (false, true),
225 JoinType::RightOuter => (true, false),
226 JoinType::FullOuter => (true, true),
227 _ => return join_type,
228 };
229
230 for expr in &predicate.conjunctions {
231 if let ExprImpl::FunctionCall(func) = expr {
232 match func.func_type() {
233 ExprType::Equal
234 | ExprType::NotEqual
235 | ExprType::LessThan
236 | ExprType::LessThanOrEqual
237 | ExprType::GreaterThan
238 | ExprType::GreaterThanOrEqual => {
239 for input in func.inputs() {
240 if let ExprImpl::InputRef(input) = input {
241 let idx = input.index;
242 if idx < left_col_num {
243 gen_null_in_left = false;
244 } else {
245 gen_null_in_right = false;
246 }
247 }
248 }
249 }
250 _ => {}
251 };
252 }
253 }
254
255 match (gen_null_in_left, gen_null_in_right) {
256 (true, true) => JoinType::FullOuter,
257 (true, false) => JoinType::RightOuter,
258 (false, true) => JoinType::LeftOuter,
259 (false, false) => JoinType::Inner,
260 }
261 }
262
263 fn to_batch_lookup_join_with_index_selection(
267 &self,
268 predicate: EqJoinPredicate,
269 batch_join: generic::Join<BatchPlanRef>,
270 ) -> Result<Option<BatchLookupJoin>> {
271 match batch_join.join_type {
272 JoinType::Inner
273 | JoinType::LeftOuter
274 | JoinType::LeftSemi
275 | JoinType::LeftAnti
276 | JoinType::AsofInner
277 | JoinType::AsofLeftOuter => {}
278 _ => return Ok(None),
279 };
280
281 let right = self.right();
283 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
285 logical_scan
286 } else {
287 return Ok(None);
288 };
289
290 let mut result_plan = None;
291 if let Some(lookup_join) =
293 self.to_batch_lookup_join(predicate.clone(), batch_join.clone())?
294 {
295 result_plan = Some(lookup_join);
296 }
297
298 if self
299 .core
300 .ctx()
301 .session_ctx()
302 .config()
303 .enable_index_selection()
304 {
305 let indexes = logical_scan.table_indexes();
306 for index in indexes {
307 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
308 let index_scan: PlanRef = index_scan.into();
309 let that = self.clone_with_left_right(self.left(), index_scan.clone());
310 let mut new_batch_join = batch_join.clone();
311 new_batch_join.right =
312 index_scan.to_batch().expect("index scan failed to batch");
313
314 if let Some(lookup_join) =
316 that.to_batch_lookup_join(predicate.clone(), new_batch_join)?
317 {
318 match &result_plan {
319 None => result_plan = Some(lookup_join),
320 Some(prev_lookup_join) => {
321 if prev_lookup_join.lookup_prefix_len()
323 < lookup_join.lookup_prefix_len()
324 {
325 result_plan = Some(lookup_join)
326 }
327 }
328 }
329 }
330 }
331 }
332 }
333
334 Ok(result_plan)
335 }
336
337 fn to_batch_lookup_join(
339 &self,
340 predicate: EqJoinPredicate,
341 logical_join: generic::Join<BatchPlanRef>,
342 ) -> Result<Option<BatchLookupJoin>> {
343 let logical_scan: &LogicalScan =
344 if let Some(logical_scan) = self.core.right.as_logical_scan() {
345 logical_scan
346 } else {
347 return Ok(None);
348 };
349 Self::gen_batch_lookup_join(logical_scan, predicate, logical_join, self.is_asof_join())
350 }
351
352 pub fn gen_batch_lookup_join(
353 logical_scan: &LogicalScan,
354 predicate: EqJoinPredicate,
355 logical_join: generic::Join<BatchPlanRef>,
356 is_as_of: bool,
357 ) -> Result<Option<BatchLookupJoin>> {
358 match logical_join.join_type {
359 JoinType::Inner
360 | JoinType::LeftOuter
361 | JoinType::LeftSemi
362 | JoinType::LeftAnti
363 | JoinType::AsofInner
364 | JoinType::AsofLeftOuter => {}
365 _ => return Ok(None),
366 };
367
368 let table = logical_scan.table();
369 let output_column_ids = logical_scan.output_column_ids();
370
371 let order_col_ids = table.order_column_ids();
374 let dist_key = table.distribution_key.clone();
375 let mut dist_key_in_order_key_pos = vec![];
377 for d in dist_key {
378 let pos = table
379 .order_column_indices()
380 .position(|x| x == d)
381 .expect("dist_key must in order_key");
382 dist_key_in_order_key_pos.push(pos);
383 }
384 let shortest_prefix_len = dist_key_in_order_key_pos
386 .iter()
387 .max()
388 .map_or(0, |pos| pos + 1);
389
390 if shortest_prefix_len == 0 {
392 return Ok(None);
393 }
394
395 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
397 for order_col_id in order_col_ids {
398 let mut found = false;
399 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
400 if order_col_id == output_column_ids[eq_idx] {
401 reorder_idx.push(i);
402 found = true;
403 break;
404 }
405 }
406 if !found {
407 break;
408 }
409 }
410 if reorder_idx.len() < shortest_prefix_len {
411 return Ok(None);
412 }
413 let lookup_prefix_len = reorder_idx.len();
414 let predicate = predicate.reorder(&reorder_idx);
415
416 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
418 let o2r = if let Some(project_expr) = project_expr {
420 project_expr
421 .into_iter()
422 .map(|x| x.as_input_ref().unwrap().index)
423 .collect_vec()
424 } else {
425 (0..logical_scan.output_col_idx().len()).collect_vec()
426 };
427 let left_schema_len = logical_join.left.schema().len();
428
429 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
430 offset: left_schema_len,
431 mapping: o2r.clone(),
432 };
433
434 let new_eq_cond = predicate
435 .eq_cond()
436 .rewrite_expr(&mut join_predicate_rewriter);
437
438 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
439 offset: left_schema_len,
440 };
441
442 let new_other_cond = predicate
443 .other_cond()
444 .clone()
445 .rewrite_expr(&mut join_predicate_rewriter)
446 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
447
448 let new_join_on = new_eq_cond.and(new_other_cond);
449 let new_predicate =
450 EqJoinPredicate::create(left_schema_len, new_scan.schema().len(), new_join_on);
451
452 if !new_predicate.has_eq() {
455 return Ok(None);
456 }
457
458 let new_join_output_indices = logical_join
461 .output_indices
462 .iter()
463 .map(|&x| {
464 if x < left_schema_len {
465 x
466 } else {
467 o2r[x - left_schema_len] + left_schema_len
468 }
469 })
470 .collect_vec();
471
472 let new_scan_output_column_ids = new_scan.output_column_ids();
473 let as_of = new_scan.as_of.clone();
474 let new_logical_scan: LogicalScan = new_scan.into();
475
476 let new_logical_join = generic::Join::new_with_eq_predicate(
478 logical_join.left,
479 new_logical_scan.to_batch()?,
480 new_predicate,
481 logical_join.join_type,
482 new_join_output_indices,
483 );
484
485 let asof_desc = is_as_of
486 .then(|| {
487 Self::get_inequality_desc_from_predicate(
488 predicate.other_cond().clone(),
489 left_schema_len,
490 )
491 })
492 .transpose()?;
493
494 Ok(Some(BatchLookupJoin::new(
495 new_logical_join,
496 table.clone(),
497 new_scan_output_column_ids,
498 lookup_prefix_len,
499 false,
500 as_of,
501 asof_desc,
502 )))
503 }
504
505 pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
506 self.core.decompose()
507 }
508
509 fn dynamic_filter_candidate(&self, predicate: &Condition) -> Option<(usize, PbType)> {
510 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
512 return None;
513 }
514
515 if !self.right().max_one_row() || self.right().schema().len() != 1 {
517 return None;
518 }
519
520 if predicate.conjunctions.len() > 1 {
522 return None;
523 }
524 let expr: ExprImpl = predicate.clone().into();
525 let (left_ref, comparator, right_ref) = expr.as_comparison_cond()?;
526
527 let left_len = self.left().schema().len();
529 let condition_cross_inputs = left_ref.index < left_len && right_ref.index == left_len;
530 if !condition_cross_inputs {
531 return None;
532 }
533
534 if self.left().schema().fields()[left_ref.index].data_type
536 != self.right().schema().fields()[0].data_type
537 {
538 return None;
539 }
540
541 if !self.output_indices().iter().all(|i| *i < left_len) {
543 return None;
544 }
545
546 Some((left_ref.index, comparator))
547 }
548
549 fn temporal_filter_candidate(&self) -> bool {
555 self.right().as_logical_now().is_some()
556 && self.dynamic_filter_candidate(self.on()).is_some()
557 }
558}
559
560impl PlanTreeNodeBinary<Logical> for LogicalJoin {
561 fn left(&self) -> PlanRef {
562 self.core.left.clone()
563 }
564
565 fn right(&self) -> PlanRef {
566 self.core.right.clone()
567 }
568
569 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
570 Self::with_core(generic::Join {
571 left,
572 right,
573 ..self.core.clone()
574 })
575 }
576
577 fn rewrite_with_left_right(
578 &self,
579 left: PlanRef,
580 left_col_change: ColIndexMapping,
581 right: PlanRef,
582 right_col_change: ColIndexMapping,
583 ) -> (Self, ColIndexMapping) {
584 let (new_on, new_output_indices) = {
585 let (mut map, _) = left_col_change.clone().into_parts();
586 let (mut right_map, _) = right_col_change.clone().into_parts();
587 for i in right_map.iter_mut().flatten() {
588 *i += left.schema().len();
589 }
590 map.append(&mut right_map);
591 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
592
593 let new_output_indices = self
594 .output_indices()
595 .iter()
596 .map(|&i| mapping.map(i))
597 .collect::<Vec<_>>();
598 let new_on = self.on().clone().rewrite_expr(&mut mapping);
599 (new_on, new_output_indices)
600 };
601
602 let join = Self::with_output_indices(
603 left,
604 right,
605 self.join_type(),
606 new_on,
607 new_output_indices.clone(),
608 );
609
610 let new_i2o = ColIndexMapping::with_remaining_columns(
611 &new_output_indices,
612 join.internal_column_num(),
613 );
614
615 let old_o2i = self.core.o2i_col_mapping();
616
617 let old_o2l = old_o2i
618 .composite(&self.core.i2l_col_mapping())
619 .composite(&left_col_change);
620 let old_o2r = old_o2i
621 .composite(&self.core.i2r_col_mapping())
622 .composite(&right_col_change);
623 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
624 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
625
626 let out_col_change = old_o2l
627 .composite(&new_l2o)
628 .union(&old_o2r.composite(&new_r2o));
629 (join, out_col_change)
630 }
631}
632
// Generates the plan-tree-node plumbing for `LogicalJoin` as a binary node of the `Logical`
// convention (presumably built on the `PlanTreeNodeBinary` impl above — see the macro).
impl_plan_tree_node_for_binary! { Logical, LogicalJoin }
634
impl ColPrunable for LogicalJoin {
    /// Prunes the join's inputs down to the columns needed by `required_cols` (output
    /// positions) plus those referenced by the `ON` condition, and rebuilds the join over them.
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        // Translate required output positions into internal column indices.
        let required_cols = required_cols
            .iter()
            .map(|i| self.output_indices()[*i])
            .collect_vec();
        let left_len = self.left().schema().fields.len();

        let total_len = self.left().schema().len() + self.right().schema().len();
        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);

        // Place them in the concatenated (left ++ right) space; for right semi/anti joins the
        // indices are treated as right-side and shifted past the left input.
        required_cols.iter().for_each(|&i| {
            if self.is_right_join() {
                resized_required_cols.insert(left_len + i);
            } else {
                resized_required_cols.insert(i);
            }
        });

        // Also keep every column the ON condition references.
        let mut visitor = CollectInputRef::new(resized_required_cols);
        self.on().visit_expr(&mut visitor);
        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();

        // Split the surviving columns by input, with right indices made right-relative.
        let mut left_required_cols = Vec::new();
        let mut right_required_cols = Vec::new();
        left_right_required_cols.iter().for_each(|&i| {
            if i < left_len {
                left_required_cols.push(i);
            } else {
                right_required_cols.push(i - left_len);
            }
        });

        // Re-express the ON condition over the pruned concatenated schema.
        let mut on = self.on().clone();
        let mut mapping =
            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
        on = on.rewrite_expr(&mut mapping);

        // Map the required internal columns to positions among the kept columns of the side(s)
        // this join type exposes in its internal schema.
        let new_output_indices = {
            let required_inputs_in_output = if self.is_left_join() {
                &left_required_cols
            } else if self.is_right_join() {
                &right_required_cols
            } else {
                &left_right_required_cols
            };

            let mapping =
                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
        };

        LogicalJoin::with_output_indices(
            self.left().prune_col(&left_required_cols, ctx),
            self.right().prune_col(&right_required_cols, ctx),
            self.join_type(),
            on,
            new_output_indices,
        )
        .into()
    }
}
700
701impl ExprRewritable<Logical> for LogicalJoin {
702 fn has_rewritable_expr(&self) -> bool {
703 true
704 }
705
706 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
707 let mut core = self.core.clone();
708 core.rewrite_exprs(r);
709 Self {
710 base: self.base.clone_with_new_plan_id(),
711 core,
712 }
713 .into()
714 }
715}
716
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions held by the join core (delegates to the core's `visit_exprs`).
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
722
723fn derive_predicate_from_eq_condition(
741 expr: &ExprImpl,
742 eq_condition: &EqJoinPredicate,
743 col_num: usize,
744 expr_is_left: bool,
745) -> Option<ExprImpl> {
746 if expr.is_impure() {
747 return None;
748 }
749 let eq_indices = eq_condition
750 .eq_indexes_typed()
751 .iter()
752 .filter_map(|(l, r)| {
753 if l.return_type() != r.return_type() {
754 None
755 } else if expr_is_left {
756 Some(l.index())
757 } else {
758 Some(r.index())
759 }
760 })
761 .collect_vec();
762 if expr
763 .collect_input_refs(col_num)
764 .ones()
765 .any(|index| !eq_indices.contains(&index))
766 {
767 return None;
769 }
770 let other_side_mapping = if expr_is_left {
773 eq_condition.eq_indexes_typed().into_iter().collect()
774 } else {
775 eq_condition
776 .eq_indexes_typed()
777 .into_iter()
778 .map(|(x, y)| (y, x))
779 .collect()
780 };
781 struct InputRefsRewriter {
782 mapping: HashMap<InputRef, InputRef>,
783 }
784 impl ExprRewriter for InputRefsRewriter {
785 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
786 self.mapping[&input_ref].clone().into()
787 }
788 }
789 Some(
790 InputRefsRewriter {
791 mapping: other_side_mapping,
792 }
793 .rewrite_expr(expr.clone()),
794 )
795}
796
797struct LookupJoinPredicateRewriter {
799 offset: usize,
800 mapping: Vec<usize>,
801}
802impl ExprRewriter for LookupJoinPredicateRewriter {
803 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
804 if input_ref.index() < self.offset {
805 input_ref.into()
806 } else {
807 InputRef::new(
808 self.mapping[input_ref.index() - self.offset] + self.offset,
809 input_ref.return_type(),
810 )
811 .into()
812 }
813 }
814}
815
816struct LookupJoinScanPredicateRewriter {
818 offset: usize,
819}
820impl ExprRewriter for LookupJoinScanPredicateRewriter {
821 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
822 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
823 }
824}
825
impl PredicatePushdown for LogicalJoin {
    /// Pushes `predicate` (expressed over this join's output columns) through the join.
    ///
    /// The steps are:
    /// 1. Narrow an outer join when the filter rejects NULL-padded rows.
    /// 2. Split the filter into left-only, right-only and ON parts.
    /// 3. Push parts of the resulting ON condition into the inputs.
    /// 4. Mirror input filters across equi-join keys where the join type allows it.
    /// 5. Push the filters into both children and rebuild the join.
    ///
    /// Whatever stays in `predicate` is applied as a filter above the rebuilt join.
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        // Re-express the predicate over the join's internal (left ++ right) columns.
        let mut predicate = {
            let mut mapping = self.core.o2i_col_mapping();
            predicate.rewrite_expr(&mut mapping)
        };

        let left_col_num = self.left().schema().len();
        let right_col_num = self.right().schema().len();
        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

        // Temporal predicates are pushed down only when this is not a process-time temporal join.
        let push_down_temporal_predicate = self.temporal_join_on().is_none();

        // Split the filter; `predicate` is taken mutably and what remains in it is applied above
        // the new join at the end.
        let (left_from_filter, right_from_filter, on) = push_down_into_join(
            &mut predicate,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        // Merge the filter's ON part into the existing condition, then move single-side parts of
        // the combined ON condition into the inputs.
        let mut new_on = self.on().clone().and(on);
        let (left_from_on, right_from_on) = push_down_join_condition(
            &mut new_on,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        let left_predicate = left_from_filter.and(left_from_on);
        let right_predicate = right_from_filter.and(right_from_on);

        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());

        // Mirror left-side filters onto the right through the equi-join keys. Only for join types
        // where right rows whose keys fail the left filter could never pair with a surviving
        // left row, so filtering them early does not change the result.
        let right_from_left = if matches!(
            join_type,
            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
        ) {
            Condition {
                conjunctions: left_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Symmetric case: mirror right-side filters onto the left.
        let left_from_right = if matches!(
            join_type,
            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
        ) {
            Condition {
                conjunctions: right_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(
                            expr,
                            &eq_condition,
                            right_col_num,
                            false,
                        )
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        let left_predicate = left_predicate.and(left_from_right);
        let right_predicate = right_predicate.and(right_from_left);

        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
        let new_join = LogicalJoin::with_output_indices(
            new_left,
            new_right,
            join_type,
            new_on,
            self.output_indices().clone(),
        );

        // Map the leftover predicate back to output columns and filter above the new join.
        let mut mapping = self.core.i2o_col_mapping();
        predicate = predicate.rewrite_expr(&mut mapping);
        LogicalFilter::create(new_join.into(), predicate)
    }
}
949
/// A borrowed `LogicalScan` used as the right side of a process-time temporal join.
///
/// `LogicalJoin::temporal_join_on` produces it only after checking that the scan is
/// `AS OF` process time.
#[derive(Clone, Copy)]
struct TemporalJoinScan<'a>(&'a LogicalScan);
952
// Lets a `TemporalJoinScan` be used wherever the wrapped `LogicalScan`'s methods are needed.
impl<'a> Deref for TemporalJoinScan<'a> {
    type Target = LogicalScan;

    fn deref(&self) -> &Self::Target {
        self.0
    }
}
960
961impl LogicalJoin {
    /// Converts both inputs to stream plans whose distributions line up on the equi-join keys.
    ///
    /// The right input is first required to be sharded by its equi-join keys.
    /// - If it ends up `HashShard`, the left input is required to follow the same distribution,
    ///   mapped through the right-to-left equi-key correspondence.
    /// - If it ends up `UpstreamHashShard`, the left input is sharded by its own equi-join keys.
    ///   If the left then ends up `HashShard`, the right is re-distributed (when needed) to
    ///   match it. If both are `UpstreamHashShard`, each side is enforced to a plain
    ///   `hash_shard` on its equi-join keys.
    ///
    /// Any other distribution hits `unreachable!()`.
    fn get_stream_input_for_hash_join(
        &self,
        predicate: &EqJoinPredicate,
        ctx: &mut ToStreamContext,
    ) -> Result<(StreamPlanRef, StreamPlanRef)> {
        use super::stream::prelude::*;

        let mut right = self.right().to_stream_with_dist_required(
            &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
            ctx,
        )?;
        // Mappings between the two sides' equi-join columns, used to translate distributions.
        let r2l =
            predicate.r2l_eq_columns_mapping(self.left().schema().len(), right.schema().len());
        let l2r =
            predicate.l2r_eq_columns_mapping(self.left().schema().len(), right.schema().len());
        let mut left;
        let right_dist = right.distribution();
        match right_dist {
            Distribution::HashShard(_) => {
                // Make the left side follow the right side's distribution.
                let left_dist = r2l
                    .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
                left = self.left().to_stream_with_dist_required(&left_dist, ctx)?;
            }
            Distribution::UpstreamHashShard(_, _) => {
                left = self.left().to_stream_with_dist_required(
                    &RequiredDist::shard_by_key(
                        self.left().schema().len(),
                        &predicate.left_eq_indexes(),
                    ),
                    ctx,
                )?;
                let left_dist = left.distribution();
                match left_dist {
                    Distribution::HashShard(_) => {
                        // Re-distribute the right side to match the left side.
                        let right_dist = l2r.rewrite_required_distribution(
                            &RequiredDist::PhysicalDist(left_dist.clone()),
                        );
                        right = right_dist.streaming_enforce_if_not_satisfies(right)?
                    }
                    Distribution::UpstreamHashShard(_, _) => {
                        // Both sides keep an upstream distribution: force a plain hash shard on
                        // the equi-join keys for each.
                        left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                            .streaming_enforce_if_not_satisfies(left)?;
                        right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                            .streaming_enforce_if_not_satisfies(right)?;
                    }
                    _ => unreachable!(),
                }
            }
            _ => unreachable!(),
        }
        Ok((left, right))
    }
1014
1015 fn to_stream_hash_join(
1016 &self,
1017 predicate: EqJoinPredicate,
1018 ctx: &mut ToStreamContext,
1019 ) -> Result<StreamPlanRef> {
1020 use super::stream::prelude::*;
1021
1022 assert!(predicate.has_eq());
1023 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1024
1025 let mut core = self.core.clone_with_inputs(left, right);
1026 core.on = generic::JoinOn::EqPredicate(predicate);
1027
1028 let stream_hash_join = StreamHashJoin::new(core.clone())?;
1037 let predicate = stream_hash_join.eq_join_predicate().clone();
1038
1039 let force_filter_inside_join = self
1040 .base
1041 .ctx()
1042 .session_ctx()
1043 .config()
1044 .streaming_force_filter_inside_join();
1045
1046 let pull_filter = self.join_type() == JoinType::Inner
1047 && stream_hash_join.eq_join_predicate().has_non_eq()
1048 && stream_hash_join.inequality_pairs().is_empty()
1049 && (!force_filter_inside_join);
1050 if pull_filter {
1051 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
1052
1053 let mut core = core;
1054 core.output_indices = default_indices.clone();
1055 let eq_cond = EqJoinPredicate::new(
1057 Condition::true_cond(),
1058 predicate.eq_keys().to_vec(),
1059 self.left().schema().len(),
1060 self.right().schema().len(),
1061 );
1062 core.on = generic::JoinOn::EqPredicate(eq_cond);
1063 let hash_join = StreamHashJoin::new(core)?.into();
1064 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
1065 let plan = StreamFilter::new(logical_filter).into();
1066 if self.output_indices() != &default_indices {
1067 let logical_project = generic::Project::with_mapping(
1068 plan,
1069 ColIndexMapping::with_remaining_columns(
1070 self.output_indices(),
1071 self.internal_column_num(),
1072 ),
1073 );
1074 Ok(StreamProject::new(logical_project).into())
1075 } else {
1076 Ok(plan)
1077 }
1078 } else {
1079 Ok(stream_hash_join.into())
1080 }
1081 }
1082
    /// Whether the right input is a scan `AS OF` process time, i.e. this is a temporal join.
    pub fn should_be_temporal_join(&self) -> bool {
        self.temporal_join_on().is_some()
    }
1086
1087 fn temporal_join_on(&self) -> Option<TemporalJoinScan<'_>> {
1088 if let Some(logical_scan) = self.core.right.as_logical_scan() {
1089 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
1090 .then_some(TemporalJoinScan(logical_scan))
1091 } else {
1092 None
1093 }
1094 }
1095
1096 fn should_be_stream_temporal_join<'a>(
1097 &'a self,
1098 ctx: &ToStreamContext,
1099 ) -> Result<Option<TemporalJoinScan<'a>>> {
1100 Ok(if let Some(scan) = self.temporal_join_on() {
1101 if let BackfillType::SnapshotBackfill = ctx.backfill_type() {
1102 return Err(RwError::from(ErrorCode::NotSupported(
1103 "Temporal join with snapshot backfill not supported".into(),
1104 "Please use arrangement backfill".into(),
1105 )));
1106 }
1107 if scan.cross_database() {
1108 return Err(RwError::from(ErrorCode::NotSupported(
1109 "Temporal join requires the lookup table to be in the same database as the stream source table".into(),
1110 "Please ensure both tables are in the same database".into(),
1111 )));
1112 }
1113 Some(scan)
1114 } else {
1115 None
1116 })
1117 }
1118
1119 fn to_stream_temporal_join_with_index_selection(
1120 &self,
1121 logical_scan: TemporalJoinScan<'_>,
1122 predicate: EqJoinPredicate,
1123 ctx: &mut ToStreamContext,
1124 ) -> Result<StreamPlanRef> {
1125 let mut result_plan: Result<StreamTemporalJoin> =
1127 self.to_stream_temporal_join(logical_scan, predicate.clone(), ctx);
1128 if let Ok(temporal_join) = &result_plan
1130 && temporal_join.eq_join_predicate().eq_indexes().len()
1131 == logical_scan.primary_key().len()
1132 {
1133 return result_plan.map(|x| x.into());
1134 }
1135 if self
1136 .core
1137 .ctx()
1138 .session_ctx()
1139 .config()
1140 .enable_index_selection()
1141 {
1142 let indexes = logical_scan.table_indexes();
1143 for index in indexes {
1144 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1146 let index_scan: PlanRef = index_scan.into();
1147 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1148 if let Ok(temporal_join) = that.to_stream_temporal_join(
1149 that.temporal_join_on().expect(
1150 "index scan created from temporal join scan must also be temporal join",
1151 ),
1152 predicate.clone(),
1153 ctx,
1154 ) {
1155 match &result_plan {
1156 Err(_) => result_plan = Ok(temporal_join),
1157 Ok(prev_temporal_join) => {
1158 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1160 < temporal_join.eq_join_predicate().eq_indexes().len()
1161 {
1162 result_plan = Ok(temporal_join)
1163 }
1164 }
1165 }
1166 }
1167 }
1168 }
1169 }
1170
1171 result_plan.map(|x| x.into())
1172 }
1173
1174 fn temporal_join_scan_predicate_pull_up(
1175 logical_scan: TemporalJoinScan<'_>,
1176 predicate: EqJoinPredicate,
1177 output_indices: &[usize],
1178 left_schema_len: usize,
1179 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1180 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1182 let o2r = if let Some(project_expr) = project_expr {
1184 project_expr
1185 .into_iter()
1186 .map(|x| x.as_input_ref().unwrap().index)
1187 .collect_vec()
1188 } else {
1189 (0..logical_scan.output_col_idx().len()).collect_vec()
1190 };
1191 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1192 offset: left_schema_len,
1193 mapping: o2r.clone(),
1194 };
1195
1196 let new_eq_cond = predicate
1197 .eq_cond()
1198 .rewrite_expr(&mut join_predicate_rewriter);
1199
1200 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1201 offset: left_schema_len,
1202 };
1203
1204 let new_other_cond = predicate
1205 .other_cond()
1206 .clone()
1207 .rewrite_expr(&mut join_predicate_rewriter)
1208 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1209
1210 let new_join_on = new_eq_cond.and(new_other_cond);
1211
1212 let new_predicate = EqJoinPredicate::create(
1213 left_schema_len,
1214 new_scan.schema().len(),
1215 new_join_on.clone(),
1216 );
1217
1218 let new_join_output_indices = output_indices
1221 .iter()
1222 .map(|&x| {
1223 if x < left_schema_len {
1224 x
1225 } else {
1226 o2r[x - left_schema_len] + left_schema_len
1227 }
1228 })
1229 .collect_vec();
1230
1231 let new_stream_table_scan =
1232 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1233 Ok((
1234 new_stream_table_scan,
1235 new_predicate,
1236 new_join_on,
1237 new_join_output_indices,
1238 ))
1239 }
1240
1241 fn to_stream_temporal_join(
1242 &self,
1243 logical_scan: TemporalJoinScan<'_>,
1244 predicate: EqJoinPredicate,
1245 ctx: &mut ToStreamContext,
1246 ) -> Result<StreamTemporalJoin> {
1247 use super::stream::prelude::*;
1248
1249 assert!(predicate.has_eq());
1250
1251 let table = logical_scan.table();
1252 let output_column_ids = logical_scan.output_column_ids();
1253
1254 let order_col_ids = table.order_column_ids();
1257 let dist_key = table.distribution_key.clone();
1258
1259 let mut dist_key_in_order_key_pos = vec![];
1260 for d in dist_key {
1261 let pos = table
1262 .order_column_indices()
1263 .position(|x| x == d)
1264 .expect("dist_key must in order_key");
1265 dist_key_in_order_key_pos.push(pos);
1266 }
1267 let shortest_prefix_len = dist_key_in_order_key_pos
1269 .iter()
1270 .max()
1271 .map_or(0, |pos| pos + 1);
1272
1273 let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
1275 for order_col_id in order_col_ids {
1276 let mut found = false;
1277 for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
1278 if order_col_id == output_column_ids[eq_idx] {
1279 reorder_idx.push(i);
1280 found = true;
1281 break;
1282 }
1283 }
1284 if !found {
1285 break;
1286 }
1287 }
1288 if reorder_idx.len() < shortest_prefix_len {
1289 return Err(RwError::from(ErrorCode::NotSupported(
1290 "Temporal join requires the equivalence join condition includes the key columns that form the distribution key of the lookup table".into(),
1291 concat!(
1292 "Use DESCRIBE <table_name> to view the table's key information.\n",
1293 "You can create an index on the lookup table to facilitate the temporal join if necessary."
1294 ).into(),
1295 )));
1296 }
1297 let lookup_prefix_len = reorder_idx.len();
1298 let predicate = predicate.reorder(&reorder_idx);
1299
1300 let required_dist = if dist_key_in_order_key_pos.is_empty() {
1301 RequiredDist::single()
1302 } else {
1303 let left_eq_indexes = predicate.left_eq_indexes();
1304 let left_dist_key = dist_key_in_order_key_pos
1305 .iter()
1306 .map(|pos| left_eq_indexes[*pos])
1307 .collect_vec();
1308
1309 RequiredDist::hash_shard(&left_dist_key)
1310 };
1311
1312 let left = self.left().to_stream(ctx)?;
1313 let left = required_dist.stream_enforce(left);
1315
1316 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1317 Self::temporal_join_scan_predicate_pull_up(
1318 logical_scan,
1319 predicate,
1320 self.output_indices(),
1321 self.left().schema().len(),
1322 )?;
1323
1324 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1325 if !new_predicate.has_eq() {
1326 return Err(RwError::from(ErrorCode::NotSupported(
1327 "Temporal join requires a non trivial join condition".into(),
1328 "Please remove the false condition of the join".into(),
1329 )));
1330 }
1331
1332 let new_logical_join = generic::Join::new(
1334 left,
1335 right,
1336 new_join_on,
1337 self.join_type(),
1338 new_join_output_indices,
1339 );
1340
1341 let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);
1342
1343 let mut new_logical_join = new_logical_join;
1344 new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1345 StreamTemporalJoin::new(new_logical_join, false)
1346 }
1347
1348 fn to_stream_nested_loop_temporal_join(
1349 &self,
1350 logical_scan: TemporalJoinScan<'_>,
1351 predicate: EqJoinPredicate,
1352 ctx: &mut ToStreamContext,
1353 ) -> Result<StreamPlanRef> {
1354 use super::stream::prelude::*;
1355 assert!(!predicate.has_eq());
1356
1357 let left = self.left().to_stream_with_dist_required(
1358 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1359 ctx,
1360 )?;
1361 assert!(left.as_stream_exchange().is_some());
1362
1363 if self.join_type() != JoinType::Inner {
1364 return Err(RwError::from(ErrorCode::NotSupported(
1365 "Temporal join requires an inner join".into(),
1366 "Please use an inner join".into(),
1367 )));
1368 }
1369
1370 if !left.append_only() {
1371 return Err(RwError::from(ErrorCode::NotSupported(
1372 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1373 "Please ensure the left hash side is append only".into(),
1374 )));
1375 }
1376
1377 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1378 Self::temporal_join_scan_predicate_pull_up(
1379 logical_scan,
1380 predicate,
1381 self.output_indices(),
1382 self.left().schema().len(),
1383 )?;
1384
1385 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1386
1387 let new_logical_join = generic::Join::new(
1389 left,
1390 right,
1391 new_join_on,
1392 self.join_type(),
1393 new_join_output_indices,
1394 );
1395
1396 let mut new_logical_join = new_logical_join;
1397 new_logical_join.on = generic::JoinOn::EqPredicate(new_predicate);
1398 Ok(StreamTemporalJoin::new(new_logical_join, true)?.into())
1399 }
1400
1401 fn to_stream_dynamic_filter(
1402 &self,
1403 predicate: Condition,
1404 ctx: &mut ToStreamContext,
1405 ) -> Result<Option<StreamPlanRef>> {
1406 use super::stream::prelude::*;
1407
1408 let Some((left_key_idx, comparator)) = self.dynamic_filter_candidate(&predicate) else {
1412 return Ok(None);
1413 };
1414
1415 let left = self.left().to_stream(ctx)?.enforce_concrete_distribution();
1416 let right = self.right().to_stream_with_dist_required(
1417 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1418 ctx,
1419 )?;
1420
1421 assert!(right.as_stream_exchange().is_some());
1422 assert_eq!(
1423 *Itertools::exactly_one(right.inputs().iter())
1424 .unwrap()
1425 .distribution(),
1426 Distribution::Single
1427 );
1428
1429 let core = DynamicFilter::new(comparator, left_key_idx, left, right);
1430 let plan = StreamDynamicFilter::new(core)?.into();
1431 if self
1433 .output_indices()
1434 .iter()
1435 .copied()
1436 .ne(0..self.left().schema().len())
1437 {
1438 let logical_project = generic::Project::with_mapping(
1441 plan,
1442 ColIndexMapping::with_remaining_columns(
1443 self.output_indices(),
1444 self.left().schema().len(),
1445 ),
1446 );
1447 Ok(Some(StreamProject::new(logical_project).into()))
1448 } else {
1449 Ok(Some(plan))
1450 }
1451 }
1452
1453 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<BatchPlanRef> {
1454 let predicate = EqJoinPredicate::create(
1455 self.left().schema().len(),
1456 self.right().schema().len(),
1457 self.on().clone(),
1458 );
1459 assert!(predicate.has_eq());
1460
1461 let join = self
1462 .core
1463 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1464
1465 Ok(self
1466 .to_batch_lookup_join(predicate, join)?
1467 .expect("Fail to convert to lookup join")
1468 .into())
1469 }
1470
1471 fn to_stream_asof_join(
1472 &self,
1473 predicate: EqJoinPredicate,
1474 ctx: &mut ToStreamContext,
1475 ) -> Result<StreamPlanRef> {
1476 use super::stream::prelude::*;
1477
1478 if predicate.eq_keys().is_empty() {
1479 return Err(ErrorCode::InvalidInputSyntax(
1480 "AsOf join requires at least 1 equal condition".to_owned(),
1481 )
1482 .into());
1483 }
1484
1485 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1486 let left_len = left.schema().len();
1487 let mut core = self.core.clone_with_inputs(left, right);
1488 core.on = generic::JoinOn::EqPredicate(predicate);
1489
1490 let inequality_desc = Self::get_inequality_desc_from_predicate(
1491 core.on
1492 .as_eq_predicate_ref()
1493 .expect("core predicate must exist")
1494 .other_cond()
1495 .clone(),
1496 left_len,
1497 )?;
1498
1499 Ok(StreamAsOfJoin::new(core, inequality_desc)?.into())
1500 }
1501
1502 fn to_batch_hash_join(
1504 &self,
1505 logical_join: generic::Join<BatchPlanRef>,
1506 predicate: EqJoinPredicate,
1507 ) -> Result<BatchPlanRef> {
1508 use super::batch::prelude::*;
1509
1510 let left_schema_len = logical_join.left.schema().len();
1511 let asof_desc = self
1512 .is_asof_join()
1513 .then(|| {
1514 Self::get_inequality_desc_from_predicate(
1515 predicate.other_cond().clone(),
1516 left_schema_len,
1517 )
1518 })
1519 .transpose()?;
1520
1521 let logical_join = generic::Join {
1522 on: generic::JoinOn::EqPredicate(predicate),
1523 ..logical_join
1524 };
1525 let batch_join = BatchHashJoin::new(logical_join, asof_desc);
1526 Ok(batch_join.into())
1527 }
1528
1529 pub fn get_inequality_desc_from_predicate(
1530 predicate: Condition,
1531 left_input_len: usize,
1532 ) -> Result<AsOfJoinDesc> {
1533 let expr: ExprImpl = predicate.into();
1534 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1535 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1536 {
1537 Ok(AsOfJoinDesc {
1538 left_idx: left_input_ref.index() as u32,
1539 right_idx: (right_input_ref.index() - left_input_len) as u32,
1540 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1541 })
1542 } else {
1543 bail!("inequal condition from the same side should be push down in optimizer");
1544 }
1545 } else {
1546 Err(ErrorCode::InvalidInputSyntax(
1547 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1548 )
1549 .into())
1550 }
1551 }
1552
1553 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1554 match expr_type {
1555 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1556 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1557 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1558 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1559 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1560 "Invalid comparison type: {}",
1561 expr_type.as_str_name()
1562 ))
1563 .into()),
1564 }
1565 }
1566}
1567
1568impl ToBatch for LogicalJoin {
1569 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1570 let predicate = EqJoinPredicate::create(
1571 self.left().schema().len(),
1572 self.right().schema().len(),
1573 self.on().clone(),
1574 );
1575
1576 let batch_join = self
1577 .core
1578 .clone_with_inputs(self.core.left.to_batch()?, self.core.right.to_batch()?);
1579
1580 let ctx = self.base.ctx();
1581 let config = ctx.session_ctx().config();
1582
1583 if predicate.has_eq() {
1584 if !predicate.eq_keys_are_type_aligned() {
1585 return Err(ErrorCode::InternalError(format!(
1586 "Join eq keys are not aligned for predicate: {predicate:?}"
1587 ))
1588 .into());
1589 }
1590 if config.batch_enable_lookup_join()
1591 && let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1592 predicate.clone(),
1593 batch_join.clone(),
1594 )?
1595 {
1596 return Ok(lookup_join.into());
1597 }
1598 self.to_batch_hash_join(batch_join, predicate)
1599 } else if self.is_asof_join() {
1600 Err(ErrorCode::InvalidInputSyntax(
1601 "AsOf join requires at least 1 equal condition".to_owned(),
1602 )
1603 .into())
1604 } else {
1605 Ok(BatchNestedLoopJoin::new(batch_join).into())
1607 }
1608 }
1609}
1610
1611impl ToStream for LogicalJoin {
1612 fn to_stream(
1613 &self,
1614 ctx: &mut ToStreamContext,
1615 ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
1616 if self
1617 .on()
1618 .conjunctions
1619 .iter()
1620 .any(|cond| cond.count_nows() > 0)
1621 {
1622 return Err(ErrorCode::NotSupported(
1623 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1624 "please refer to https://docs.risingwave.com/processing/sql/temporal-filters for more information".to_owned()).into());
1625 }
1626
1627 let predicate = EqJoinPredicate::create(
1628 self.left().schema().len(),
1629 self.right().schema().len(),
1630 self.on().clone(),
1631 );
1632
1633 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1634 self.to_stream_asof_join(predicate, ctx)
1635 } else if predicate.has_eq() {
1636 if !predicate.eq_keys_are_type_aligned() {
1637 return Err(ErrorCode::InternalError(format!(
1638 "Join eq keys are not aligned for predicate: {predicate:?}"
1639 ))
1640 .into());
1641 }
1642
1643 if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1644 self.to_stream_temporal_join_with_index_selection(scan, predicate, ctx)
1645 } else {
1646 self.to_stream_hash_join(predicate, ctx)
1647 }
1648 } else if let Some(scan) = self.should_be_stream_temporal_join(ctx)? {
1649 self.to_stream_nested_loop_temporal_join(scan, predicate, ctx)
1650 } else if let Some(dynamic_filter) =
1651 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1652 {
1653 Ok(dynamic_filter)
1654 } else {
1655 Err(RwError::from(ErrorCode::NotSupported(
1656 "streaming nested-loop join".to_owned(),
1657 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1658 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1659 See also: https://docs.risingwave.com/processing/sql/dynamic-filters".to_owned(),
1660 )))
1661 }
1662 }
1663
1664 fn logical_rewrite_for_stream(
1665 &self,
1666 ctx: &mut RewriteStreamContext,
1667 ) -> Result<(PlanRef, ColIndexMapping)> {
1668 let eq_indexes = self.eq_indexes();
1669 let (logical_left, logical_right) = if eq_indexes.is_empty() {
1670 (self.left(), self.right())
1671 } else {
1672 let lhs_join_key_idx = eq_indexes.iter().map(|(l, _)| *l).collect_vec();
1673 if self.should_be_temporal_join() {
1674 (
1675 try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
1676 self.right(),
1677 )
1678 } else {
1679 let rhs_join_key_idx = eq_indexes.iter().map(|(_, r)| *r).collect_vec();
1680 (
1681 try_enforce_locality_requirement(self.left(), &lhs_join_key_idx),
1682 try_enforce_locality_requirement(self.right(), &rhs_join_key_idx),
1683 )
1684 }
1685 };
1686
1687 let (left, left_col_change) = logical_left.logical_rewrite_for_stream(ctx)?;
1688 let left_len = left.schema().len();
1689 let (right, right_col_change) = logical_right.logical_rewrite_for_stream(ctx)?;
1690 let (join, out_col_change) = self.rewrite_with_left_right(
1691 left.clone(),
1692 left_col_change,
1693 right.clone(),
1694 right_col_change,
1695 );
1696
1697 let mapping = ColIndexMapping::with_remaining_columns(
1698 join.output_indices(),
1699 join.internal_column_num(),
1700 );
1701
1702 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1703 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1704
1705 let mut left_to_add = left
1707 .expect_stream_key()
1708 .iter()
1709 .cloned()
1710 .filter(|i| l2o.try_map(*i).is_none())
1711 .collect_vec();
1712
1713 let mut right_to_add = right
1714 .expect_stream_key()
1715 .iter()
1716 .filter(|&&i| r2o.try_map(i).is_none())
1717 .map(|&i| i + left_len)
1718 .collect_vec();
1719
1720 let right_len = right.schema().len();
1723 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1724
1725 let either_or_both = self.core.add_which_join_key_to_pk();
1726
1727 for (lk, rk) in eq_predicate.eq_indexes() {
1728 match either_or_both {
1729 EitherOrBoth::Left(_) => {
1730 if l2o.try_map(lk).is_none() {
1731 left_to_add.push(lk);
1732 }
1733 }
1734 EitherOrBoth::Right(_) => {
1735 if r2o.try_map(rk).is_none() {
1736 right_to_add.push(rk + left_len)
1737 }
1738 }
1739 EitherOrBoth::Both(_, _) => {
1740 if l2o.try_map(lk).is_none() {
1741 left_to_add.push(lk);
1742 }
1743 if r2o.try_map(rk).is_none() {
1744 right_to_add.push(rk + left_len)
1745 }
1746 }
1747 };
1748 }
1749 let left_to_add = left_to_add.into_iter().unique();
1750 let right_to_add = right_to_add.into_iter().unique();
1751 let mut new_output_indices = join.output_indices().clone();
1754 if !join.is_right_join() {
1755 new_output_indices.extend(left_to_add);
1756 }
1757 if !join.is_left_join() {
1758 new_output_indices.extend(right_to_add);
1759 }
1760
1761 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1762
1763 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1764 let l2o = join_with_pk
1767 .core
1768 .l2i_col_mapping()
1769 .composite(&join_with_pk.core.i2o_col_mapping());
1770 let r2o = join_with_pk
1771 .core
1772 .r2i_col_mapping()
1773 .composite(&join_with_pk.core.i2o_col_mapping());
1774 let mut left_right_keys = join_with_pk
1775 .left()
1776 .expect_stream_key()
1777 .iter()
1778 .map(|i| l2o.map(*i))
1779 .collect_vec();
1780 left_right_keys.extend(
1781 join_with_pk
1782 .right()
1783 .expect_stream_key()
1784 .iter()
1785 .map(|i| r2o.map(*i)),
1786 );
1787 left_right_keys.extend(
1788 eq_predicate
1789 .eq_indexes()
1790 .iter()
1791 .flat_map(|(lk, rk)| [l2o.map(*lk), r2o.map(*rk)]),
1792 );
1793 let left_right_keys = left_right_keys.into_iter().unique().collect_vec();
1794 let plan: PlanRef = join_with_pk.into();
1795 LogicalFilter::filter_out_all_null_keys(plan, &left_right_keys)
1796 } else {
1797 join_with_pk.into()
1798 };
1799
1800 Ok((plan, out_col_change))
1802 }
1803
1804 fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1805 if !self.temporal_filter_candidate() {
1807 return None;
1808 }
1809
1810 let o2i_mapping = self.core.o2i_col_mapping();
1812 let left_input_columns = columns
1813 .iter()
1814 .map(|&col| o2i_mapping.try_map(col))
1815 .collect::<Option<Vec<usize>>>()?;
1816 if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) {
1817 return Some(
1818 self.clone_with_left_right(better_left_plan, self.right())
1819 .into(),
1820 );
1821 }
1822 None
1823 }
1824}
1825
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Six `Int32` fields named `v1` through `v6`.
    fn six_int_fields() -> Vec<Field> {
        (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
            .collect()
    }

    /// Two empty `LogicalValues` sharing one mock optimizer context.
    ///
    /// The left one has schema `fields[0..3]` and the right one has schema `fields[3..6]`.
    fn left_right_values(fields: &[Field]) -> (LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        (left, right)
    }

    /// An `Int32` input reference to column `index`.
    fn int_input_ref(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// The condition `$lhs = $rhs` over two `Int32` input references.
    fn int_eq(lhs: usize, rhs: usize) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![int_input_ref(lhs), int_input_ref(rhs)]).unwrap(),
        ))
    }

    /// Inner join `ON $1 = $3`, pruned to output columns `[2, 3]`.
    ///
    /// Checks that each input keeps only the columns it still needs and that the ON condition
    /// is remapped onto the pruned inputs.
    #[tokio::test]
    async fn test_prune_join() {
        let fields = six_int_fields();
        let (left, right) = left_right_values(&fields);
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(int_eq(1, 3)),
        )
        .into();

        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        assert_eq!(
            left.as_logical_values().unwrap().schema().fields(),
            &fields[1..3]
        );
        let right = join.right();
        assert_eq!(
            right.as_logical_values().unwrap().schema().fields(),
            &fields[3..4]
        );
    }

    /// Semi/anti joins output a single side.
    ///
    /// Pruning keeps the requested columns of that side. For joins where `is_right_join()` holds,
    /// the output side starts at `fields[3]`.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let fields = six_int_fields();
        let (left, right) = left_right_values(&fields);
        let on = int_eq(1, 4);

        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();

            for required_cols in [vec![0], vec![0, 1, 2]] {
                let plan =
                    join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
                let as_plan = plan.as_logical_join().unwrap();
                assert_eq!(as_plan.schema().fields().len(), required_cols.len());
                for (pos, &col) in required_cols.iter().enumerate() {
                    assert_eq!(as_plan.schema().fields()[pos], fields[offset + col]);
                }
            }
        }
    }

    /// Inner join `ON $1 = $3`, pruned to `[1, 3]`.
    ///
    /// Every kept column is output directly, so no projection is needed.
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let fields = six_int_fields();
        let (left, right) = left_right_values(&fields);
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(int_eq(1, 3)),
        )
        .into();

        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        assert_eq!(
            left.as_logical_values().unwrap().schema().fields(),
            &fields[1..2]
        );
        let right = join.right();
        assert_eq!(
            right.as_logical_values().unwrap().schema().fields(),
            &fields[3..4]
        );
    }

    /// `ON $1 = $3 AND $2 = 42` lowers to a batch hash join.
    ///
    /// The equality becomes the eq condition and the literal comparison becomes the non-eq
    /// condition.
    #[tokio::test]
    async fn test_join_to_batch() {
        let fields = six_int_fields();
        let (left, right) = left_right_values(&fields);

        let eq_cond = int_eq(1, 3);
        let forty_two = ExprImpl::Literal(Box::new(Literal::new(
            Datum::Some(42_i32.into()),
            DataType::Int32,
        )));
        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![int_input_ref(2), forty_two]).unwrap(),
        ));
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on_cond),
        );

        let result = logical_join.to_batch().unwrap();
        let hash_join = result.as_batch_hash_join().unwrap();
        let eq_join_predicate = hash_join.eq_join_predicate();
        assert_eq!(ExprImpl::from(eq_join_predicate.eq_cond()), eq_cond);
        assert_eq!(
            *eq_join_predicate
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Ignored placeholder; it has no body.
    #[tokio::test]
    #[ignore]
    async fn test_join_to_stream() {}

    /// Pruning honours the order of `required_cols` (`[3, 2]`) in the join's output schema.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let fields = six_int_fields();
        let (left, right) = left_right_values(&fields);
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(int_eq(1, 3)),
        )
        .into();

        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        assert_eq!(
            left.as_logical_values().unwrap().schema().fields(),
            &fields[1..3]
        );
        let right = join.right();
        assert_eq!(
            right.as_logical_values().unwrap().schema().fields(),
            &fields[3..4]
        );
    }

    /// Functional dependencies derived for each join type.
    ///
    /// Setup: left `(l0, l1)` with `l0 --> l1`, right `(r0, r1, r2)` with `r0 --> r1, r2`, and the
    /// condition `ON l0 = 0 AND l1 = r0`.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock();
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            // l0 --> l1
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            // r0 --> r1, r2
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };

        // l0 = 0 AND l1 = r0
        let l0_is_zero: ExprImpl = FunctionCall::new(
            Type::Equal,
            vec![int_input_ref(0), ExprImpl::literal_int(0)],
        )
        .unwrap()
        .into();
        let on: ExprImpl = FunctionCall::new(Type::And, vec![l0_is_zero, int_eq(1, 3)])
            .unwrap()
            .into();

        let fd = FunctionalDependency::with_indices;
        let expected_fd_set = [
            (
                JoinType::Inner,
                HashSet::from([
                    // inherited from the left input: l0 --> l1
                    fd(5, &[0], &[1]),
                    // inherited from the right input: r0 --> r1, r2
                    fd(5, &[2], &[3, 4]),
                    // l0 = 0 makes l0 a constant
                    fd(5, &[], &[0]),
                    // l1 = r0 makes them determine each other
                    fd(5, &[1], &[3]),
                    fd(5, &[3], &[1]),
                ]),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (JoinType::RightOuter, HashSet::from([fd(5, &[2], &[3, 4])])),
            (JoinType::LeftOuter, HashSet::from([fd(5, &[0], &[1])])),
            (JoinType::LeftSemi, HashSet::from([fd(2, &[0], &[1])])),
            (JoinType::LeftAnti, HashSet::from([fd(2, &[0], &[1])])),
            (JoinType::RightSemi, HashSet::from([fd(3, &[0], &[1, 2])])),
            (JoinType::RightAnti, HashSet::from([fd(3, &[0], &[1, 2])])),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set: HashSet<_> = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect();
            assert_eq!(fd_set, expected_res, "{join_type:?}");
        }
    }
}