1use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31 ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeBinary, PredicatePushdown,
32 StreamHashJoin, StreamProject, ToBatch, ToStream, generic,
33};
34use crate::error::{ErrorCode, Result, RwError};
35use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::DynamicFilter;
38use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
39use crate::optimizer::plan_node::utils::IndicesDisplay;
40use crate::optimizer::plan_node::{
41 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
42 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
43 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
44};
45use crate::optimizer::plan_visitor::LogicalCardinalityExt;
46use crate::optimizer::property::{Distribution, Order, RequiredDist};
47use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
48
/// Logical plan node for a join of two inputs.
///
/// The join description (left/right inputs, `on` condition, join type and output indices) lives
/// in `core`; `base` is derived from that core in [`LogicalJoin::with_core`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan base built from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The generic join description wrapped by this node.
    core: generic::Join<PlanRef>,
}
60
61impl Distill for LogicalJoin {
62 fn distill<'a>(&self) -> XmlNode<'a> {
63 let verbose = self.base.ctx().is_explain_verbose();
64 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
65 vec.push(("type", Pretty::debug(&self.join_type())));
66
67 let concat_schema = self.core.concat_schema();
68 let cond = Pretty::debug(&ConditionDisplay {
69 condition: self.on(),
70 input_schema: &concat_schema,
71 });
72 vec.push(("on", cond));
73
74 if verbose {
75 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
76 vec.push(("output", data));
77 }
78
79 childless_record("LogicalJoin", vec)
80 }
81}
82
83impl LogicalJoin {
84 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
85 let core = generic::Join::with_full_output(left, right, join_type, on);
86 Self::with_core(core)
87 }
88
89 pub(crate) fn with_output_indices(
90 left: PlanRef,
91 right: PlanRef,
92 join_type: JoinType,
93 on: Condition,
94 output_indices: Vec<usize>,
95 ) -> Self {
96 let core = generic::Join::new(left, right, on, join_type, output_indices);
97 Self::with_core(core)
98 }
99
100 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
101 let base = PlanBase::new_logical_with_core(&core);
102 LogicalJoin { base, core }
103 }
104
105 pub fn create(
106 left: PlanRef,
107 right: PlanRef,
108 join_type: JoinType,
109 on_clause: ExprImpl,
110 ) -> PlanRef {
111 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
112 }
113
/// Returns the core's internal column count (delegates to `generic::Join::internal_column_num`).
///
/// In this file it is the column count that `output_indices` index into.
pub fn internal_column_num(&self) -> usize {
    self.core.internal_column_num()
}
117
/// Delegates to the core's `i2l_col_mapping_ignore_join_type`.
///
/// NOTE(review): presumably maps internal column indices to left-input indices regardless of
/// the join type — confirm against `generic::Join`.
pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
    self.core.i2l_col_mapping_ignore_join_type()
}
121
/// Delegates to the core's `i2r_col_mapping_ignore_join_type`.
///
/// NOTE(review): presumably maps internal column indices to right-input indices regardless of
/// the join type — confirm against `generic::Join`.
pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
    self.core.i2r_col_mapping_ignore_join_type()
}
125
/// Returns the join's `on` condition.
pub fn on(&self) -> &Condition {
    &self.core.on
}
130
131 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
133 let input_refs = self
134 .core
135 .on
136 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
137 let index_group = input_refs
138 .ones()
139 .chunk_by(|i| *i < self.core.left.schema().len());
140 let left_index = index_group
141 .into_iter()
142 .next()
143 .map_or(vec![], |group| group.1.collect_vec());
144 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
145 group
146 .1
147 .map(|i| i - self.core.left.schema().len())
148 .collect_vec()
149 });
150 (left_index, right_index)
151 }
152
/// Returns the join type stored in the core.
pub fn join_type(&self) -> JoinType {
    self.core.join_type
}
157
/// Delegates to the core's `eq_indexes`.
///
/// NOTE(review): presumably the `(left, right)` column index pairs of the equality conditions
/// in `on` — confirm against `generic::Join::eq_indexes`.
pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
    self.core.eq_indexes()
}
162
/// Returns the core's output indices, i.e. which internal columns the join outputs and in
/// which order.
pub fn output_indices(&self) -> &Vec<usize> {
    &self.core.output_indices
}
167
168 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
170 Self::with_core(generic::Join {
171 output_indices,
172 ..self.core.clone()
173 })
174 }
175
176 pub fn clone_with_cond(&self, on: Condition) -> Self {
178 Self::with_core(generic::Join {
179 on,
180 ..self.core.clone()
181 })
182 }
183
184 pub fn is_left_join(&self) -> bool {
185 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
186 }
187
188 pub fn is_right_join(&self) -> bool {
189 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
190 }
191
/// Delegates to the core's `is_full_out`.
///
/// NOTE(review): presumably reports whether the join outputs all of its internal columns —
/// confirm against `generic::Join::is_full_out`.
pub fn is_full_out(&self) -> bool {
    self.core.is_full_out()
}
195
196 pub fn is_asof_join(&self) -> bool {
197 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
198 }
199
200 pub fn output_indices_are_trivial(&self) -> bool {
201 self.output_indices() == &(0..self.internal_column_num()).collect_vec()
202 }
203
204 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
209 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
210 JoinType::LeftOuter => (false, true),
211 JoinType::RightOuter => (true, false),
212 JoinType::FullOuter => (true, true),
213 _ => return join_type,
214 };
215
216 for expr in &predicate.conjunctions {
217 if let ExprImpl::FunctionCall(func) = expr {
218 match func.func_type() {
219 ExprType::Equal
220 | ExprType::NotEqual
221 | ExprType::LessThan
222 | ExprType::LessThanOrEqual
223 | ExprType::GreaterThan
224 | ExprType::GreaterThanOrEqual => {
225 for input in func.inputs() {
226 if let ExprImpl::InputRef(input) = input {
227 let idx = input.index;
228 if idx < left_col_num {
229 gen_null_in_left = false;
230 } else {
231 gen_null_in_right = false;
232 }
233 }
234 }
235 }
236 _ => {}
237 };
238 }
239 }
240
241 match (gen_null_in_left, gen_null_in_right) {
242 (true, true) => JoinType::FullOuter,
243 (true, false) => JoinType::RightOuter,
244 (false, true) => JoinType::LeftOuter,
245 (false, false) => JoinType::Inner,
246 }
247 }
248
249 fn to_batch_lookup_join_with_index_selection(
253 &self,
254 predicate: EqJoinPredicate,
255 logical_join: generic::Join<PlanRef>,
256 ) -> Result<Option<BatchLookupJoin>> {
257 match logical_join.join_type {
258 JoinType::Inner
259 | JoinType::LeftOuter
260 | JoinType::LeftSemi
261 | JoinType::LeftAnti
262 | JoinType::AsofInner
263 | JoinType::AsofLeftOuter => {}
264 _ => return Ok(None),
265 };
266
267 let right = self.right();
269 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
271 logical_scan
272 } else {
273 return Ok(None);
274 };
275
276 let mut result_plan = None;
277 if let Some(lookup_join) =
279 self.to_batch_lookup_join(predicate.clone(), logical_join.clone())?
280 {
281 result_plan = Some(lookup_join);
282 }
283
284 let indexes = logical_scan.indexes();
285 for index in indexes {
286 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
287 let index_scan: PlanRef = index_scan.into();
288 let that = self.clone_with_left_right(self.left(), index_scan.clone());
289 let mut new_logical_join = logical_join.clone();
290 new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");
291
292 if let Some(lookup_join) =
294 that.to_batch_lookup_join(predicate.clone(), new_logical_join)?
295 {
296 match &result_plan {
297 None => result_plan = Some(lookup_join),
298 Some(prev_lookup_join) => {
299 if prev_lookup_join.lookup_prefix_len()
301 < lookup_join.lookup_prefix_len()
302 {
303 result_plan = Some(lookup_join)
304 }
305 }
306 }
307 }
308 }
309 }
310
311 Ok(result_plan)
312 }
313
/// Tries to plan this join as a [`BatchLookupJoin`] that looks up the right-side table scan.
///
/// Returns `Ok(None)` when a lookup join is not applicable:
/// * the join type is not one of inner / left outer / left semi / left anti / as-of inner /
///   as-of left outer;
/// * the right input is not a [`LogicalScan`];
/// * the table has no distribution key;
/// * the right equality keys do not cover a long enough prefix of the table's order key;
/// * no equality condition remains once the scan's predicate is pulled into the join.
///
/// For as-of joins the inequality description is derived from the non-equality part of the
/// (reordered) predicate, and an error from that derivation is propagated.
fn to_batch_lookup_join(
    &self,
    predicate: EqJoinPredicate,
    logical_join: generic::Join<PlanRef>,
) -> Result<Option<BatchLookupJoin>> {
    match logical_join.join_type {
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {}
        _ => return Ok(None),
    };

    // The lookup side is taken from `self`, not from `logical_join`.
    let right = self.right();
    let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
        logical_scan
    } else {
        return Ok(None);
    };
    let table_desc = logical_scan.table_desc().clone();
    let output_column_ids = logical_scan.output_column_ids();

    // Locate every distribution key column inside the order key.
    let order_col_ids = table_desc.order_column_ids();
    let order_key = table_desc.order_column_indices();
    let dist_key = table_desc.distribution_key.clone();
    let mut dist_key_in_order_key_pos = vec![];
    for d in dist_key {
        let pos = order_key
            .iter()
            .position(|&x| x == d)
            .expect("dist_key must in order_key");
        dist_key_in_order_key_pos.push(pos);
    }
    // Shortest order-key prefix that still contains all distribution key columns
    // (0 when the distribution key is empty).
    let shortest_prefix_len = dist_key_in_order_key_pos
        .iter()
        .max()
        .map_or(0, |pos| pos + 1);

    // An empty distribution key is not supported by this path.
    if shortest_prefix_len == 0 {
        return Ok(None);
    }

    // Walk the order key columns in order and, for each, find the equality condition whose
    // right column is that order key column. Stop at the first order key column that has no
    // matching equality condition, so `reorder_idx` describes a contiguous order-key prefix.
    let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
    for order_col_id in order_col_ids {
        let mut found = false;
        for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
            if order_col_id == output_column_ids[eq_idx] {
                reorder_idx.push(i);
                found = true;
                break;
            }
        }
        if !found {
            break;
        }
    }
    // The matched prefix must at least cover the distribution key.
    if reorder_idx.len() < shortest_prefix_len {
        return Ok(None);
    }
    let lookup_prefix_len = reorder_idx.len();
    // Reorder the equality conditions to follow the order key.
    let predicate = predicate.reorder(&reorder_idx);

    // Pull the scan's own predicate (and projection, if any) out of the scan.
    let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
    // `o2r`: old scan output position -> column position in `new_scan`.
    // The `unwrap` assumes every projection expression is a plain input ref.
    let o2r = if let Some(project_expr) = project_expr {
        project_expr
            .into_iter()
            .map(|x| x.as_input_ref().unwrap().index)
            .collect_vec()
    } else {
        (0..logical_scan.output_col_idx().len()).collect_vec()
    };
    let left_schema_len = logical_join.left.schema().len();

    // Rewrites right-side column references of the join predicate through `o2r`.
    let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
        offset: left_schema_len,
        mapping: o2r.clone(),
    };

    let new_eq_cond = predicate
        .eq_cond()
        .rewrite_expr(&mut join_predicate_rewriter);

    // Shifts the pulled-up scan predicate's column references past the left columns.
    let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
        offset: left_schema_len,
    };

    let new_other_cond = predicate
        .other_cond()
        .clone()
        .rewrite_expr(&mut join_predicate_rewriter)
        .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

    let new_join_on = new_eq_cond.and(new_other_cond);
    let new_predicate = EqJoinPredicate::create(
        left_schema_len,
        new_scan.schema().len(),
        new_join_on.clone(),
    );

    // After merging the scan predicate, the condition may no longer contain an equality;
    // a lookup join cannot be used in that case.
    if !new_predicate.has_eq() {
        return Ok(None);
    }

    // Output indices that pointed at the old scan's columns must follow `o2r` as well.
    let new_join_output_indices = logical_join
        .output_indices
        .iter()
        .map(|&x| {
            if x < left_schema_len {
                x
            } else {
                o2r[x - left_schema_len] + left_schema_len
            }
        })
        .collect_vec();

    let new_scan_output_column_ids = new_scan.output_column_ids();
    let as_of = new_scan.as_of.clone();

    // Rebuild the join core because its right input changed.
    let new_logical_join = generic::Join::new(
        logical_join.left,
        new_scan.into(),
        new_join_on,
        logical_join.join_type,
        new_join_output_indices,
    );

    // Only as-of joins carry an inequality description; it is derived from the non-equality
    // conditions of the reordered (pre-rewrite) predicate.
    let asof_desc = self
        .is_asof_join()
        .then(|| {
            Self::get_inequality_desc_from_predicate(
                predicate.other_cond().clone(),
                left_schema_len,
            )
        })
        .transpose()?;

    Ok(Some(BatchLookupJoin::new(
        new_logical_join,
        new_predicate,
        table_desc,
        new_scan_output_column_ids,
        lookup_prefix_len,
        false,
        as_of,
        asof_desc,
    )))
}
478
/// Consumes the node and returns the parts of its core via `generic::Join::decompose`.
///
/// Judging by the return type, the parts are the two inputs, the `on` condition, the join
/// type and the output indices.
pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
    self.core.decompose()
}
482}
483
484impl PlanTreeNodeBinary for LogicalJoin {
485 fn left(&self) -> PlanRef {
486 self.core.left.clone()
487 }
488
489 fn right(&self) -> PlanRef {
490 self.core.right.clone()
491 }
492
493 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
494 Self::with_core(generic::Join {
495 left,
496 right,
497 ..self.core.clone()
498 })
499 }
500
501 fn rewrite_with_left_right(
502 &self,
503 left: PlanRef,
504 left_col_change: ColIndexMapping,
505 right: PlanRef,
506 right_col_change: ColIndexMapping,
507 ) -> (Self, ColIndexMapping) {
508 let (new_on, new_output_indices) = {
509 let (mut map, _) = left_col_change.clone().into_parts();
510 let (mut right_map, _) = right_col_change.clone().into_parts();
511 for i in right_map.iter_mut().flatten() {
512 *i += left.schema().len();
513 }
514 map.append(&mut right_map);
515 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
516
517 let new_output_indices = self
518 .output_indices()
519 .iter()
520 .map(|&i| mapping.map(i))
521 .collect::<Vec<_>>();
522 let new_on = self.on().clone().rewrite_expr(&mut mapping);
523 (new_on, new_output_indices)
524 };
525
526 let join = Self::with_output_indices(
527 left,
528 right,
529 self.join_type(),
530 new_on,
531 new_output_indices.clone(),
532 );
533
534 let new_i2o = ColIndexMapping::with_remaining_columns(
535 &new_output_indices,
536 join.internal_column_num(),
537 );
538
539 let old_o2i = self.core.o2i_col_mapping();
540
541 let old_o2l = old_o2i
542 .composite(&self.core.i2l_col_mapping())
543 .composite(&left_col_change);
544 let old_o2r = old_o2i
545 .composite(&self.core.i2r_col_mapping())
546 .composite(&right_col_change);
547 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
548 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
549
550 let out_col_change = old_o2l
551 .composite(&new_l2o)
552 .union(&old_o2r.composite(&new_r2o));
553 (join, out_col_change)
554 }
555}
556
// Invokes `impl_plan_tree_node_for_binary!` for `LogicalJoin`. The macro is defined elsewhere;
// NOTE(review): presumably it derives the generic plan-tree-node impl from the
// `PlanTreeNodeBinary` impl of this type — confirm.
impl_plan_tree_node_for_binary! { LogicalJoin }
558
impl ColPrunable for LogicalJoin {
    /// Prunes the join so that it outputs only `required_cols` (indices into the join's
    /// current output), and recursively prunes both inputs.
    ///
    /// Each input keeps the columns that are either required in the output or referenced by
    /// the `on` condition. The `on` condition and the output indices are rewritten to the
    /// pruned input schemas.
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        // Translate output positions into the indices stored in `output_indices`.
        let required_cols = required_cols
            .iter()
            .map(|i| self.output_indices()[*i])
            .collect_vec();
        let left_len = self.left().schema().fields.len();

        let total_len = self.left().schema().len() + self.right().schema().len();
        let mut resized_required_cols = FixedBitSet::with_capacity(total_len);

        // Mark the required columns in the concatenated (left ++ right) index space.
        // For right semi/anti joins the indices are shifted by `left_len`.
        // NOTE(review): presumably because output indices of such joins are relative to the
        // right input only — confirm.
        required_cols.iter().for_each(|&i| {
            if self.is_right_join() {
                resized_required_cols.insert(left_len + i);
            } else {
                resized_required_cols.insert(i);
            }
        });

        // Add every column referenced by the `on` condition to the required set.
        let mut visitor = CollectInputRef::new(resized_required_cols);
        self.on().visit_expr(&mut visitor);
        let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();

        // Split the required columns per input; right-side indices are re-based to the right
        // input's own schema.
        let mut left_required_cols = Vec::new();
        let mut right_required_cols = Vec::new();
        left_right_required_cols.iter().for_each(|&i| {
            if i < left_len {
                left_required_cols.push(i);
            } else {
                right_required_cols.push(i - left_len);
            }
        });

        // Rewrite the `on` condition to the pruned concatenated schema.
        let mut on = self.on().clone();
        let mut mapping =
            ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
        on = on.rewrite_expr(&mut mapping);

        let new_output_indices = {
            // Semi/anti joins map their output through one side only; every other join type
            // maps through the concatenation of both sides.
            let required_inputs_in_output = if self.is_left_join() {
                &left_required_cols
            } else if self.is_right_join() {
                &right_required_cols
            } else {
                &left_right_required_cols
            };

            let mapping =
                ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
            required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
        };

        LogicalJoin::with_output_indices(
            self.left().prune_col(&left_required_cols, ctx),
            self.right().prune_col(&right_required_cols, ctx),
            self.join_type(),
            on,
            new_output_indices,
        )
        .into()
    }
}
624
625impl ExprRewritable for LogicalJoin {
626 fn has_rewritable_expr(&self) -> bool {
627 true
628 }
629
630 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
631 let mut core = self.core.clone();
632 core.rewrite_exprs(r);
633 Self {
634 base: self.base.clone_with_new_plan_id(),
635 core,
636 }
637 .into()
638 }
639}
640
impl ExprVisitable for LogicalJoin {
    /// Visits the expressions of this join by delegating to the core's `visit_exprs`.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
646
647fn derive_predicate_from_eq_condition(
665 expr: &ExprImpl,
666 eq_condition: &EqJoinPredicate,
667 col_num: usize,
668 expr_is_left: bool,
669) -> Option<ExprImpl> {
670 if expr.is_impure() {
671 return None;
672 }
673 let eq_indices = eq_condition
674 .eq_indexes_typed()
675 .iter()
676 .filter_map(|(l, r)| {
677 if l.return_type() != r.return_type() {
678 None
679 } else if expr_is_left {
680 Some(l.index())
681 } else {
682 Some(r.index())
683 }
684 })
685 .collect_vec();
686 if expr
687 .collect_input_refs(col_num)
688 .ones()
689 .any(|index| !eq_indices.contains(&index))
690 {
691 return None;
693 }
694 let other_side_mapping = if expr_is_left {
697 eq_condition.eq_indexes_typed().into_iter().collect()
698 } else {
699 eq_condition
700 .eq_indexes_typed()
701 .into_iter()
702 .map(|(x, y)| (y, x))
703 .collect()
704 };
705 struct InputRefsRewriter {
706 mapping: HashMap<InputRef, InputRef>,
707 }
708 impl ExprRewriter for InputRefsRewriter {
709 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
710 self.mapping[&input_ref].clone().into()
711 }
712 }
713 Some(
714 InputRefsRewriter {
715 mapping: other_side_mapping,
716 }
717 .rewrite_expr(expr.clone()),
718 )
719}
720
721struct LookupJoinPredicateRewriter {
723 offset: usize,
724 mapping: Vec<usize>,
725}
726impl ExprRewriter for LookupJoinPredicateRewriter {
727 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
728 if input_ref.index() < self.offset {
729 input_ref.into()
730 } else {
731 InputRef::new(
732 self.mapping[input_ref.index() - self.offset] + self.offset,
733 input_ref.return_type(),
734 )
735 .into()
736 }
737 }
738}
739
740struct LookupJoinScanPredicateRewriter {
742 offset: usize,
743}
744impl ExprRewriter for LookupJoinScanPredicateRewriter {
745 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
746 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
747 }
748}
749
impl PredicatePushdown for LogicalJoin {
    /// Pushes `predicate` (a filter over the join's output) through the join.
    ///
    /// Steps, in order:
    /// 1. rewrite `predicate` from output columns to internal columns;
    /// 2. possibly weaken an outer join via `simplify_outer`;
    /// 3. split the filter into a left part, a right part and a part merged into `on`;
    /// 4. split the merged `on` condition into further left/right parts;
    /// 5. for some join types, mirror one side's predicates onto the other side through the
    ///    equality keys;
    /// 6. push the per-side predicates into the inputs and rebuild the join;
    /// 7. wrap the join in a filter holding whatever is left of `predicate`.
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
        // Express the incoming predicate over the join's internal columns.
        let mut predicate = {
            let mut mapping = self.core.o2i_col_mapping();
            predicate.rewrite_expr(&mut mapping)
        };

        let left_col_num = self.left().schema().len();
        let right_col_num = self.right().schema().len();
        // The filter on top may allow an outer join to be weakened.
        let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());

        // The flag is off when the right side is a process-time temporal scan.
        let push_down_temporal_predicate = !self.should_be_temporal_join();

        // Takes what can move out of `predicate` (it is passed mutably).
        let (left_from_filter, right_from_filter, on) = push_down_into_join(
            &mut predicate,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        // Merge the part that joined the `on` condition, then split `on` itself per side.
        let mut new_on = self.on().clone().and(on);
        let (left_from_on, right_from_on) = push_down_join_condition(
            &mut new_on,
            left_col_num,
            right_col_num,
            join_type,
            push_down_temporal_predicate,
        );

        let left_predicate = left_from_filter.and(left_from_on);
        let right_predicate = right_from_filter.and(right_from_on);

        // Equality keys of the remaining `on` condition, used to mirror predicates below.
        let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());

        // Left-side predicates mirrored onto the right side; only for these join types.
        let right_from_left = if matches!(
            join_type,
            JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
        ) {
            Condition {
                conjunctions: left_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        // Right-side predicates mirrored onto the left side; only for these join types.
        let left_from_right = if matches!(
            join_type,
            JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
        ) {
            Condition {
                conjunctions: right_predicate
                    .conjunctions
                    .iter()
                    .filter_map(|expr| {
                        derive_predicate_from_eq_condition(
                            expr,
                            &eq_condition,
                            right_col_num,
                            false,
                        )
                    })
                    .collect(),
            }
        } else {
            Condition::true_cond()
        };

        let left_predicate = left_predicate.and(left_from_right);
        let right_predicate = right_predicate.and(right_from_left);

        let new_left = self.left().predicate_pushdown(left_predicate, ctx);
        let new_right = self.right().predicate_pushdown(right_predicate, ctx);
        // Output indices are unchanged; only inputs, join type and `on` may differ.
        let new_join = LogicalJoin::with_output_indices(
            new_left,
            new_right,
            join_type,
            new_on,
            self.output_indices().clone(),
        );

        // Whatever remains of `predicate` goes back to output columns and stays in a filter
        // above the join.
        let mut mapping = self.core.i2o_col_mapping();
        predicate = predicate.rewrite_expr(&mut mapping);
        LogicalFilter::create(new_join.into(), predicate)
    }
}
873
874impl LogicalJoin {
/// Converts both inputs to stream plans with distributions suitable for a hash join on the
/// equality keys of `predicate`.
///
/// The right input is converted first, required to be sharded by the right equality keys.
/// The left input is then aligned with whatever distribution the right input ended up with.
/// Returns `(left, right)`.
///
/// # Panics
///
/// Panics (`unreachable!`) if an input ends up with a distribution other than `HashShard` or
/// `UpstreamHashShard` at the points matched below.
fn get_stream_input_for_hash_join(
    &self,
    predicate: &EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<(PlanRef, PlanRef)> {
    use super::stream::prelude::*;

    let mut right = self.right().to_stream_with_dist_required(
        &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
        ctx,
    )?;
    // Still the logical left input at this point; converted in the branches below.
    let mut left = self.left();

    // Column mappings between the two sides' equality keys.
    let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
    let l2r = predicate.l2r_eq_columns_mapping(left.schema().len(), right.schema().len());

    let right_dist = right.distribution();
    match right_dist {
        Distribution::HashShard(_) => {
            // Translate the right side's distribution to left columns and require it of the
            // left input.
            let left_dist = r2l
                .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
            left = left.to_stream_with_dist_required(&left_dist, ctx)?;
        }
        Distribution::UpstreamHashShard(_, _) => {
            // Convert the left input on its own equality keys first, then reconcile.
            left = left.to_stream_with_dist_required(
                &RequiredDist::shard_by_key(
                    self.left().schema().len(),
                    &predicate.left_eq_indexes(),
                ),
                ctx,
            )?;
            let left_dist = left.distribution();
            match left_dist {
                Distribution::HashShard(_) => {
                    // Translate the left side's distribution to right columns and enforce it
                    // on the right input unless already satisfied.
                    let right_dist = l2r.rewrite_required_distribution(
                        &RequiredDist::PhysicalDist(left_dist.clone()),
                    );
                    right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
                }
                Distribution::UpstreamHashShard(_, _) => {
                    // Both sides are upstream-sharded: enforce hash sharding by the equality
                    // keys on each side.
                    left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                        .enforce_if_not_satisfies(left, &Order::any())?;
                    right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                        .enforce_if_not_satisfies(right, &Order::any())?;
                }
                _ => unreachable!(),
            }
        }
        _ => unreachable!(),
    }
    Ok((left, right))
}
927
/// Plans this join as a stream hash join on the equality keys of `predicate`.
///
/// For an inner join that has non-equality conditions but no inequality pairs, those
/// conditions are pulled out of the join: the plan becomes a hash join on the equality
/// conditions only (with default output indices), a `StreamFilter` with the non-equality
/// conditions on top, and a `StreamProject` restoring the original output indices if they
/// differ from the default.
///
/// # Panics
///
/// Panics if `predicate` has no equality condition.
fn to_stream_hash_join(
    &self,
    predicate: EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<PlanRef> {
    use super::stream::prelude::*;

    assert!(predicate.has_eq());
    let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;

    let logical_join = self.clone_with_left_right(left, right);

    // Built up front so that its inequality pairs can be inspected below.
    let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
    let pull_filter = self.join_type() == JoinType::Inner
        && stream_hash_join.eq_join_predicate().has_non_eq()
        && stream_hash_join.inequality_pairs().is_empty();
    if pull_filter {
        let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();

        // The filter above the join needs every internal column, so output all of them.
        let logical_join = logical_join.clone_with_output_indices(default_indices.clone());
        // Same equality keys, but no other condition.
        let eq_cond = EqJoinPredicate::new(
            Condition::true_cond(),
            predicate.eq_keys().to_vec(),
            self.left().schema().len(),
            self.right().schema().len(),
        );
        let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
        let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
        let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
        let plan = StreamFilter::new(logical_filter).into();
        if self.output_indices() != &default_indices {
            // Restore the originally requested output columns.
            let logical_project = generic::Project::with_mapping(
                plan,
                ColIndexMapping::with_remaining_columns(
                    self.output_indices(),
                    self.internal_column_num(),
                ),
            );
            Ok(StreamProject::new(logical_project).into())
        } else {
            Ok(plan)
        }
    } else {
        Ok(stream_hash_join.into())
    }
}
980
981 fn should_be_temporal_join(&self) -> bool {
982 let right = self.right();
983 if let Some(logical_scan) = right.as_logical_scan() {
984 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
985 } else {
986 false
987 }
988 }
989
990 fn to_stream_temporal_join_with_index_selection(
991 &self,
992 predicate: EqJoinPredicate,
993 ctx: &mut ToStreamContext,
994 ) -> Result<StreamTemporalJoin> {
995 let right = self.right();
997 let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
999
1000 let mut result_plan = self.to_stream_temporal_join(predicate.clone(), ctx);
1002 if let Ok(temporal_join) = &result_plan
1004 && temporal_join.eq_join_predicate().eq_indexes().len()
1005 == logical_scan.primary_key().len()
1006 {
1007 return result_plan;
1008 }
1009 let indexes = logical_scan.indexes();
1010 for index in indexes {
1011 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1013 let index_scan: PlanRef = index_scan.into();
1014 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1015 if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1016 match &result_plan {
1017 Err(_) => result_plan = Ok(temporal_join),
1018 Ok(prev_temporal_join) => {
1019 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1021 < temporal_join.eq_join_predicate().eq_indexes().len()
1022 {
1023 result_plan = Ok(temporal_join)
1024 }
1025 }
1026 }
1027 }
1028 }
1029 }
1030
1031 result_plan
1032 }
1033
1034 fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1035 let Some(logical_scan) = right.as_logical_scan() else {
1036 return Err(RwError::from(ErrorCode::NotSupported(
1037 "Temporal join requires a table scan as its lookup table".into(),
1038 "Please provide a table scan".into(),
1039 )));
1040 };
1041
1042 if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1043 return Err(RwError::from(ErrorCode::NotSupported(
1044 "Temporal join requires a table defined as temporal table".into(),
1045 "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1046 )));
1047 }
1048 Ok(logical_scan)
1049 }
1050
1051 fn temporal_join_scan_predicate_pull_up(
1052 logical_scan: &LogicalScan,
1053 predicate: EqJoinPredicate,
1054 output_indices: &[usize],
1055 left_schema_len: usize,
1056 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1057 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1059 let o2r = if let Some(project_expr) = project_expr {
1061 project_expr
1062 .into_iter()
1063 .map(|x| x.as_input_ref().unwrap().index)
1064 .collect_vec()
1065 } else {
1066 (0..logical_scan.output_col_idx().len()).collect_vec()
1067 };
1068 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1069 offset: left_schema_len,
1070 mapping: o2r.clone(),
1071 };
1072
1073 let new_eq_cond = predicate
1074 .eq_cond()
1075 .rewrite_expr(&mut join_predicate_rewriter);
1076
1077 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1078 offset: left_schema_len,
1079 };
1080
1081 let new_other_cond = predicate
1082 .other_cond()
1083 .clone()
1084 .rewrite_expr(&mut join_predicate_rewriter)
1085 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1086
1087 let new_join_on = new_eq_cond.and(new_other_cond);
1088
1089 let new_predicate = EqJoinPredicate::create(
1090 left_schema_len,
1091 new_scan.schema().len(),
1092 new_join_on.clone(),
1093 );
1094
1095 let new_join_output_indices = output_indices
1098 .iter()
1099 .map(|&x| {
1100 if x < left_schema_len {
1101 x
1102 } else {
1103 o2r[x - left_schema_len] + left_schema_len
1104 }
1105 })
1106 .collect_vec();
1107 let new_stream_table_scan =
1109 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1110 Ok((
1111 new_stream_table_scan,
1112 new_predicate,
1113 new_join_on,
1114 new_join_output_indices,
1115 ))
1116 }
1117
/// Plans this join as a (non-nested-loop) [`StreamTemporalJoin`] that looks up the right-side
/// table scan by a prefix of its order key.
///
/// # Errors
///
/// Returns `ErrorCode::NotSupported` when the right input is not a process-time temporal
/// table scan, when the equality keys do not cover a long enough prefix of the lookup
/// table's order key, or when no equality condition remains after the scan predicate is
/// pulled up. Errors from converting the left input are propagated.
///
/// # Panics
///
/// Panics if `predicate` has no equality condition.
fn to_stream_temporal_join(
    &self,
    predicate: EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<StreamTemporalJoin> {
    use super::stream::prelude::*;

    assert!(predicate.has_eq());

    let right = self.right();

    let logical_scan = Self::check_temporal_rhs(&right)?;

    let table_desc = logical_scan.table_desc();
    let output_column_ids = logical_scan.output_column_ids();

    // Locate every distribution key column inside the order key.
    let order_col_ids = table_desc.order_column_ids();
    let order_key = table_desc.order_column_indices();
    let dist_key = table_desc.distribution_key.clone();

    let mut dist_key_in_order_key_pos = vec![];
    for d in dist_key {
        let pos = order_key
            .iter()
            .position(|&x| x == d)
            .expect("dist_key must in order_key");
        dist_key_in_order_key_pos.push(pos);
    }
    // Shortest order-key prefix that still contains all distribution key columns
    // (0 when the distribution key is empty).
    let shortest_prefix_len = dist_key_in_order_key_pos
        .iter()
        .max()
        .map_or(0, |pos| pos + 1);

    // Walk the order key columns in order and, for each, find the equality condition whose
    // right column is that order key column. Stop at the first order key column without a
    // matching equality condition, so `reorder_idx` describes a contiguous order-key prefix.
    let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
    for order_col_id in order_col_ids {
        let mut found = false;
        for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
            if order_col_id == output_column_ids[eq_idx] {
                reorder_idx.push(i);
                found = true;
                break;
            }
        }
        if !found {
            break;
        }
    }
    // The matched prefix must at least cover the distribution key.
    if reorder_idx.len() < shortest_prefix_len {
        return Err(RwError::from(ErrorCode::NotSupported(
            "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
            "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
        )));
    }
    let lookup_prefix_len = reorder_idx.len();
    // Reorder the equality conditions to follow the order key.
    let predicate = predicate.reorder(&reorder_idx);

    // Distribution required of the left input: single when the lookup table has no
    // distribution key, otherwise hash-sharded by the left columns paired with the table's
    // distribution key columns.
    let required_dist = if dist_key_in_order_key_pos.is_empty() {
        RequiredDist::single()
    } else {
        let left_eq_indexes = predicate.left_eq_indexes();
        let left_dist_key = dist_key_in_order_key_pos
            .iter()
            .map(|pos| left_eq_indexes[*pos])
            .collect_vec();

        RequiredDist::hash_shard(&left_dist_key)
    };

    let left = self.left().to_stream(ctx)?;
    // `enforce` (not `enforce_if_not_satisfies`) is called here.
    // NOTE(review): presumably this always adds an exchange on the left side so the join
    // can be co-located with the no-shuffle right side — confirm.
    let left = required_dist.enforce(left, &Order::any());

    let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
        Self::temporal_join_scan_predicate_pull_up(
            logical_scan,
            predicate,
            self.output_indices(),
            self.left().schema().len(),
        )?;

    let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
    // After merging the scan predicate, the condition may no longer contain an equality.
    if !new_predicate.has_eq() {
        return Err(RwError::from(ErrorCode::NotSupported(
            "Temporal join requires a non trivial join condition".into(),
            "Please remove the false condition of the join".into(),
        )));
    }

    // Rebuild the join core because both inputs changed.
    let new_logical_join = generic::Join::new(
        left,
        right,
        new_join_on,
        self.join_type(),
        new_join_output_indices,
    );

    // Keep only the equality keys that form the matched order-key prefix.
    let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);

    // The last argument is `false` here and `true` in the nested-loop variant.
    Ok(StreamTemporalJoin::new(
        new_logical_join,
        new_predicate,
        false,
    ))
}
1228
1229 fn to_stream_nested_loop_temporal_join(
1230 &self,
1231 predicate: EqJoinPredicate,
1232 ctx: &mut ToStreamContext,
1233 ) -> Result<StreamTemporalJoin> {
1234 use super::stream::prelude::*;
1235 assert!(!predicate.has_eq());
1236
1237 let left = self.left().to_stream_with_dist_required(
1238 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1239 ctx,
1240 )?;
1241 assert!(left.as_stream_exchange().is_some());
1242
1243 if self.join_type() != JoinType::Inner {
1244 return Err(RwError::from(ErrorCode::NotSupported(
1245 "Temporal join requires an inner join".into(),
1246 "Please use an inner join".into(),
1247 )));
1248 }
1249
1250 if !left.append_only() {
1251 return Err(RwError::from(ErrorCode::NotSupported(
1252 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1253 "Please ensure the left hash side is append only".into(),
1254 )));
1255 }
1256
1257 let right = self.right();
1258 let logical_scan = Self::check_temporal_rhs(&right)?;
1259
1260 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1261 Self::temporal_join_scan_predicate_pull_up(
1262 logical_scan,
1263 predicate,
1264 self.output_indices(),
1265 self.left().schema().len(),
1266 )?;
1267
1268 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1269
1270 let new_logical_join = generic::Join::new(
1272 left,
1273 right,
1274 new_join_on,
1275 self.join_type(),
1276 new_join_output_indices,
1277 );
1278
1279 Ok(StreamTemporalJoin::new(
1280 new_logical_join,
1281 new_predicate,
1282 true,
1283 ))
1284 }
1285
1286 fn to_stream_dynamic_filter(
1287 &self,
1288 predicate: Condition,
1289 ctx: &mut ToStreamContext,
1290 ) -> Result<Option<PlanRef>> {
1291 use super::stream::prelude::*;
1292
1293 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1299 return Ok(None);
1300 }
1301
1302 if !self.right().max_one_row() {
1304 return Ok(None);
1305 }
1306 if self.right().schema().len() != 1 {
1307 return Ok(None);
1308 }
1309
1310 if predicate.conjunctions.len() > 1 {
1312 return Ok(None);
1313 }
1314 let expr: ExprImpl = predicate.into();
1315 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1316 Some(v) => v,
1317 None => return Ok(None),
1318 };
1319
1320 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1321 && right_ref.index == self.left().schema().len() ;
1322 if !condition_cross_inputs {
1323 return Ok(None);
1325 }
1326
1327 if self.left().schema().fields()[left_ref.index].data_type
1329 != self.right().schema().fields()[0].data_type
1330 {
1331 return Ok(None);
1332 }
1333
1334 let all_output_from_left = self
1336 .output_indices()
1337 .iter()
1338 .all(|i| *i < self.left().schema().len());
1339 if !all_output_from_left {
1340 return Ok(None);
1341 }
1342
1343 let left = self.left().to_stream(ctx)?;
1344 let right = self.right().to_stream_with_dist_required(
1345 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1346 ctx,
1347 )?;
1348
1349 assert!(right.as_stream_exchange().is_some());
1350 assert_eq!(
1351 *right.inputs().iter().exactly_one().unwrap().distribution(),
1352 Distribution::Single
1353 );
1354
1355 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1356 let plan = StreamDynamicFilter::new(core).into();
1357 if self
1359 .output_indices()
1360 .iter()
1361 .copied()
1362 .ne(0..self.left().schema().len())
1363 {
1364 let logical_project = generic::Project::with_mapping(
1367 plan,
1368 ColIndexMapping::with_remaining_columns(
1369 self.output_indices(),
1370 self.left().schema().len(),
1371 ),
1372 );
1373 Ok(Some(StreamProject::new(logical_project).into()))
1374 } else {
1375 Ok(Some(plan))
1376 }
1377 }
1378
1379 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
1380 let predicate = EqJoinPredicate::create(
1381 self.left().schema().len(),
1382 self.right().schema().len(),
1383 self.on().clone(),
1384 );
1385 assert!(predicate.has_eq());
1386
1387 let mut logical_join = self.core.clone();
1388 logical_join.left = logical_join.left.to_batch()?;
1389 logical_join.right = logical_join.right.to_batch()?;
1390
1391 Ok(self
1392 .to_batch_lookup_join(predicate, logical_join)?
1393 .expect("Fail to convert to lookup join")
1394 .into())
1395 }
1396
1397 fn to_stream_asof_join(
1398 &self,
1399 predicate: EqJoinPredicate,
1400 ctx: &mut ToStreamContext,
1401 ) -> Result<StreamAsOfJoin> {
1402 use super::stream::prelude::*;
1403
1404 if predicate.eq_keys().is_empty() {
1405 return Err(ErrorCode::InvalidInputSyntax(
1406 "AsOf join requires at least 1 equal condition".to_owned(),
1407 )
1408 .into());
1409 }
1410
1411 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1412 let left_len = left.schema().len();
1413 let logical_join = self.clone_with_left_right(left, right);
1414
1415 let inequality_desc =
1416 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1417
1418 Ok(StreamAsOfJoin::new(
1419 logical_join.core.clone(),
1420 predicate,
1421 inequality_desc,
1422 ))
1423 }
1424
1425 fn to_batch_hash_join(
1427 &self,
1428 logical_join: generic::Join<PlanRef>,
1429 predicate: EqJoinPredicate,
1430 ) -> Result<PlanRef> {
1431 use super::batch::prelude::*;
1432
1433 let left_schema_len = logical_join.left.schema().len();
1434 let asof_desc = self
1435 .is_asof_join()
1436 .then(|| {
1437 Self::get_inequality_desc_from_predicate(
1438 predicate.other_cond().clone(),
1439 left_schema_len,
1440 )
1441 })
1442 .transpose()?;
1443
1444 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1445 Ok(batch_join.into())
1446 }
1447
1448 pub fn get_inequality_desc_from_predicate(
1449 predicate: Condition,
1450 left_input_len: usize,
1451 ) -> Result<AsOfJoinDesc> {
1452 let expr: ExprImpl = predicate.into();
1453 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1454 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1455 {
1456 Ok(AsOfJoinDesc {
1457 left_idx: left_input_ref.index() as u32,
1458 right_idx: (right_input_ref.index() - left_input_len) as u32,
1459 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1460 })
1461 } else {
1462 bail!("inequal condition from the same side should be push down in optimizer");
1463 }
1464 } else {
1465 Err(ErrorCode::InvalidInputSyntax(
1466 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1467 )
1468 .into())
1469 }
1470 }
1471
1472 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1473 match expr_type {
1474 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1475 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1476 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1477 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1478 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1479 "Invalid comparison type: {}",
1480 expr_type.as_str_name()
1481 ))
1482 .into()),
1483 }
1484 }
1485}
1486
1487impl ToBatch for LogicalJoin {
1488 fn to_batch(&self) -> Result<PlanRef> {
1489 let predicate = EqJoinPredicate::create(
1490 self.left().schema().len(),
1491 self.right().schema().len(),
1492 self.on().clone(),
1493 );
1494
1495 let mut logical_join = self.core.clone();
1496 logical_join.left = logical_join.left.to_batch()?;
1497 logical_join.right = logical_join.right.to_batch()?;
1498
1499 let ctx = self.base.ctx();
1500 let config = ctx.session_ctx().config();
1501
1502 if predicate.has_eq() {
1503 if !predicate.eq_keys_are_type_aligned() {
1504 return Err(ErrorCode::InternalError(format!(
1505 "Join eq keys are not aligned for predicate: {predicate:?}"
1506 ))
1507 .into());
1508 }
1509 if config.batch_enable_lookup_join() {
1510 if let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1511 predicate.clone(),
1512 logical_join.clone(),
1513 )? {
1514 return Ok(lookup_join.into());
1515 }
1516 }
1517 self.to_batch_hash_join(logical_join, predicate)
1518 } else if self.is_asof_join() {
1519 return Err(ErrorCode::InvalidInputSyntax(
1520 "AsOf join requires at least 1 equal condition".to_owned(),
1521 )
1522 .into());
1523 } else {
1524 Ok(BatchNestedLoopJoin::new(logical_join).into())
1526 }
1527 }
1528}
1529
1530impl ToStream for LogicalJoin {
1531 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1532 if self
1533 .on()
1534 .conjunctions
1535 .iter()
1536 .any(|cond| cond.count_nows() > 0)
1537 {
1538 return Err(ErrorCode::NotSupported(
1539 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1540 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1541 }
1542
1543 let predicate = EqJoinPredicate::create(
1544 self.left().schema().len(),
1545 self.right().schema().len(),
1546 self.on().clone(),
1547 );
1548
1549 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1550 self.to_stream_asof_join(predicate, ctx).map(|x| x.into())
1551 } else if predicate.has_eq() {
1552 if !predicate.eq_keys_are_type_aligned() {
1553 return Err(ErrorCode::InternalError(format!(
1554 "Join eq keys are not aligned for predicate: {predicate:?}"
1555 ))
1556 .into());
1557 }
1558
1559 if self.should_be_temporal_join() {
1560 self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1561 .map(|x| x.into())
1562 } else {
1563 self.to_stream_hash_join(predicate, ctx)
1564 }
1565 } else if self.should_be_temporal_join() {
1566 self.to_stream_nested_loop_temporal_join(predicate, ctx)
1567 .map(|x| x.into())
1568 } else if let Some(dynamic_filter) =
1569 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1570 {
1571 Ok(dynamic_filter)
1572 } else {
1573 Err(RwError::from(ErrorCode::NotSupported(
1574 "streaming nested-loop join".to_owned(),
1575 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1576 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1577 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1578 )))
1579 }
1580 }
1581
1582 fn logical_rewrite_for_stream(
1583 &self,
1584 ctx: &mut RewriteStreamContext,
1585 ) -> Result<(PlanRef, ColIndexMapping)> {
1586 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1587 let left_len = left.schema().len();
1588 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1589 let (join, out_col_change) = self.rewrite_with_left_right(
1590 left.clone(),
1591 left_col_change,
1592 right.clone(),
1593 right_col_change,
1594 );
1595
1596 let mapping = ColIndexMapping::with_remaining_columns(
1597 join.output_indices(),
1598 join.internal_column_num(),
1599 );
1600
1601 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1602 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1603
1604 let mut left_to_add = left
1606 .expect_stream_key()
1607 .iter()
1608 .cloned()
1609 .filter(|i| l2o.try_map(*i).is_none())
1610 .collect_vec();
1611
1612 let mut right_to_add = right
1613 .expect_stream_key()
1614 .iter()
1615 .filter(|&&i| r2o.try_map(i).is_none())
1616 .map(|&i| i + left_len)
1617 .collect_vec();
1618
1619 let right_len = right.schema().len();
1622 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1623
1624 let either_or_both = self.core.add_which_join_key_to_pk();
1625
1626 for (lk, rk) in eq_predicate.eq_indexes() {
1627 match either_or_both {
1628 EitherOrBoth::Left(_) => {
1629 if l2o.try_map(lk).is_none() {
1630 left_to_add.push(lk);
1631 }
1632 }
1633 EitherOrBoth::Right(_) => {
1634 if r2o.try_map(rk).is_none() {
1635 right_to_add.push(rk + left_len)
1636 }
1637 }
1638 EitherOrBoth::Both(_, _) => {
1639 if l2o.try_map(lk).is_none() {
1640 left_to_add.push(lk);
1641 }
1642 if r2o.try_map(rk).is_none() {
1643 right_to_add.push(rk + left_len)
1644 }
1645 }
1646 };
1647 }
1648 let left_to_add = left_to_add.into_iter().unique();
1649 let right_to_add = right_to_add.into_iter().unique();
1650 let mut new_output_indices = join.output_indices().clone();
1653 if !join.is_right_join() {
1654 new_output_indices.extend(left_to_add);
1655 }
1656 if !join.is_left_join() {
1657 new_output_indices.extend(right_to_add);
1658 }
1659
1660 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1661
1662 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1663 let l2o = join_with_pk
1666 .core
1667 .l2i_col_mapping()
1668 .composite(&join_with_pk.core.i2o_col_mapping());
1669 let r2o = join_with_pk
1670 .core
1671 .r2i_col_mapping()
1672 .composite(&join_with_pk.core.i2o_col_mapping());
1673 let left_right_stream_keys = join_with_pk
1674 .left()
1675 .expect_stream_key()
1676 .iter()
1677 .map(|i| l2o.map(*i))
1678 .chain(
1679 join_with_pk
1680 .right()
1681 .expect_stream_key()
1682 .iter()
1683 .map(|i| r2o.map(*i)),
1684 )
1685 .collect_vec();
1686 let plan: PlanRef = join_with_pk.into();
1687 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1688 } else {
1689 join_with_pk.into()
1690 };
1691
1692 Ok((plan, out_col_change))
1694 }
1695}
1696
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, Datum};
    use risingwave_pb::expr::expr_node::Type;

    use super::*;
    use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;
    use crate::optimizer::property::FunctionalDependency;

    /// Builds the fixture shared by most tests: six `Int32` fields named `v1`..`v6`, an empty
    /// left `LogicalValues` over the first three and an empty right `LogicalValues` over the
    /// last three, both created in a fresh mock optimizer context.
    async fn mock_inputs() -> (Vec<Field>, LogicalValues, LogicalValues) {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = (1..7)
            .map(|i| Field::with_name(DataType::Int32, format!("v{i}")))
            .collect();
        let left = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[0..3].to_vec(),
            },
            ctx.clone(),
        );
        let right = LogicalValues::new(
            vec![],
            Schema {
                fields: fields[3..6].to_vec(),
            },
            ctx,
        );
        (fields, left, right)
    }

    /// Builds an `Int32` input reference to column `index`.
    fn input_ref(index: usize) -> ExprImpl {
        ExprImpl::InputRef(Box::new(InputRef::new(index, DataType::Int32)))
    }

    /// Builds the expression `$left = $right` over two `Int32` input references.
    fn equal_on(left: usize, right: usize) -> ExprImpl {
        ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::Equal, vec![input_ref(left), input_ref(right)]).unwrap(),
        ))
    }

    /// Pruning an inner join on `$1 = $3` to columns `[2, 3]` keeps the join key on the left
    /// input and remaps the join condition accordingly.
    #[tokio::test]
    async fn test_prune_join() {
        let (fields, left, right) = mock_inputs().await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(equal_on(1, 3)),
        )
        .into();

        let required_cols = vec![2, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[2]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Semi and anti joins only output one side; pruning must keep the requested columns of
    /// that side.
    #[tokio::test]
    async fn test_prune_semi_join() {
        let (fields, left, right) = mock_inputs().await;
        let on = equal_on(1, 4);
        for join_type in [
            JoinType::LeftSemi,
            JoinType::RightSemi,
            JoinType::LeftAnti,
            JoinType::RightAnti,
        ] {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );

            // Right semi/anti joins output the right fields, which start at `fields[3]`.
            let offset = if join.is_right_join() { 3 } else { 0 };
            let join: PlanRef = join.into();

            let required_cols = vec![0];
            let plan =
                join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 1);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);

            let required_cols = vec![0, 1, 2];
            let plan =
                join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
            let as_plan = plan.as_logical_join().unwrap();
            assert_eq!(as_plan.schema().fields().len(), 3);
            assert_eq!(as_plan.schema().fields()[0], fields[offset]);
            assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
            assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
        }
    }

    /// Pruning an inner join on `$1 = $3` to exactly the join-key columns `[1, 3]` leaves a
    /// single column on each input.
    #[tokio::test]
    async fn test_prune_join_no_project() {
        let (fields, left, right) = mock_inputs().await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(equal_on(1, 3)),
        )
        .into();

        let required_cols = vec![1, 3];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[1]);
        assert_eq!(join.schema().fields()[1], fields[3]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 1);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..2]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// An inner join on `$1 = $3 AND $2 = 42` converts to a batch hash join whose equal
    /// condition is `$1 = $3` and whose first non-equal conjunction is `$2 = 42`.
    #[tokio::test]
    async fn test_join_to_batch() {
        let (_fields, left, right) = mock_inputs().await;

        let eq_cond = equal_on(1, 3);
        let non_eq_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(
                Type::Equal,
                vec![
                    input_ref(2),
                    ExprImpl::Literal(Box::new(Literal::new(
                        Datum::Some(42_i32.into()),
                        DataType::Int32,
                    ))),
                ],
            )
            .unwrap(),
        ));
        let on_cond = ExprImpl::FunctionCall(Box::new(
            FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
        ));

        let logical_join = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(on_cond),
        );

        let result = logical_join.to_batch().unwrap();

        let hash_join = result.as_batch_hash_join().unwrap();
        assert_eq!(
            ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
            eq_cond
        );
        assert_eq!(
            *hash_join
                .eq_join_predicate()
                .non_eq_cond()
                .conjunctions
                .first()
                .unwrap(),
            non_eq_cond
        );
    }

    /// Placeholder for the stream conversion test; it is ignored and has no body.
    #[tokio::test]
    #[ignore]
    async fn test_join_to_stream() {}

    /// Pruning to columns `[3, 2]` must preserve the requested column order in the output
    /// schema of the join.
    #[tokio::test]
    async fn test_join_column_prune_with_order_required() {
        let (fields, left, right) = mock_inputs().await;
        let join: PlanRef = LogicalJoin::new(
            left.into(),
            right.into(),
            JoinType::Inner,
            Condition::with_expr(equal_on(1, 3)),
        )
        .into();

        let required_cols = vec![3, 2];
        let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));

        let join = plan.as_logical_join().unwrap();
        assert_eq!(join.schema().fields().len(), 2);
        assert_eq!(join.schema().fields()[0], fields[3]);
        assert_eq!(join.schema().fields()[1], fields[2]);

        let expr: ExprImpl = join.on().clone().into();
        let call = expr.as_function_call().unwrap();
        assert_eq_input_ref!(&call.inputs()[0], 0);
        assert_eq_input_ref!(&call.inputs()[1], 2);

        let left = join.left();
        let left = left.as_logical_values().unwrap();
        assert_eq!(left.schema().fields(), &fields[1..3]);
        let right = join.right();
        let right = right.as_logical_values().unwrap();
        assert_eq!(right.schema().fields(), &fields[3..4]);
    }

    /// Checks the functional dependencies derived for every join type.
    ///
    /// The left input has columns `l0, l1` with `l0 -> l1`; the right input has columns
    /// `r0, r1, r2` with `r0 -> r1, r2`. The join condition is `$0 = 0 AND $1 = $3`.
    #[tokio::test]
    async fn fd_derivation_inner_outer_join() {
        let ctx = OptimizerContext::mock().await;
        let left = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "l0"),
                Field::with_name(DataType::Int32, "l1"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1]);
            values
        };
        let right = {
            let fields: Vec<Field> = vec![
                Field::with_name(DataType::Int32, "r0"),
                Field::with_name(DataType::Int32, "r1"),
                Field::with_name(DataType::Int32, "r2"),
            ];
            let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
            values
                .base
                .functional_dependency_mut()
                .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
            values
        };
        let on: ExprImpl = FunctionCall::new(
            Type::And,
            vec![
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(0, DataType::Int32).into(),
                        ExprImpl::literal_int(0),
                    ],
                )
                .unwrap()
                .into(),
                FunctionCall::new(
                    Type::Equal,
                    vec![
                        InputRef::new(1, DataType::Int32).into(),
                        InputRef::new(3, DataType::Int32).into(),
                    ],
                )
                .unwrap()
                .into(),
            ],
        )
        .unwrap()
        .into();
        let expected_fd_set = [
            (
                JoinType::Inner,
                [
                    // Carried over from the left input.
                    FunctionalDependency::with_indices(5, &[0], &[1]),
                    // Carried over from the right input, shifted by the left width.
                    FunctionalDependency::with_indices(5, &[2], &[3, 4]),
                    // From `$0 = 0`: column 0 is constant.
                    FunctionalDependency::with_indices(5, &[], &[0]),
                    // From `$1 = $3`: the two columns determine each other.
                    FunctionalDependency::with_indices(5, &[1], &[3]),
                    FunctionalDependency::with_indices(5, &[3], &[1]),
                ]
                .into_iter()
                .collect::<HashSet<_>>(),
            ),
            (JoinType::FullOuter, HashSet::new()),
            (
                JoinType::RightOuter,
                [FunctionalDependency::with_indices(5, &[2], &[3, 4])]
                    .into_iter()
                    .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftOuter,
                [FunctionalDependency::with_indices(5, &[0], &[1])]
                    .into_iter()
                    .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftSemi,
                [FunctionalDependency::with_indices(2, &[0], &[1])]
                    .into_iter()
                    .collect::<HashSet<_>>(),
            ),
            (
                JoinType::LeftAnti,
                [FunctionalDependency::with_indices(2, &[0], &[1])]
                    .into_iter()
                    .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightSemi,
                [FunctionalDependency::with_indices(3, &[0], &[1, 2])]
                    .into_iter()
                    .collect::<HashSet<_>>(),
            ),
            (
                JoinType::RightAnti,
                [FunctionalDependency::with_indices(3, &[0], &[1, 2])]
                    .into_iter()
                    .collect::<HashSet<_>>(),
            ),
        ];

        for (join_type, expected_res) in expected_fd_set {
            let join = LogicalJoin::new(
                left.clone().into(),
                right.clone().into(),
                join_type,
                Condition::with_expr(on.clone()),
            );
            let fd_set = join
                .functional_dependency()
                .as_dependencies()
                .iter()
                .cloned()
                .collect::<HashSet<_>>();
            assert_eq!(fd_set, expected_res);
        }
    }
}