1use std::collections::HashMap;
16
17use fixedbitset::FixedBitSet;
18use itertools::{EitherOrBoth, Itertools};
19use pretty_xmlish::{Pretty, XmlNode};
20use risingwave_expr::bail;
21use risingwave_pb::expr::expr_node::PbType;
22use risingwave_pb::plan_common::{AsOfJoinDesc, JoinType, PbAsOfJoinInequalityType};
23use risingwave_pb::stream_plan::StreamScanType;
24use risingwave_sqlparser::ast::AsOf;
25
26use super::generic::{
27 GenericPlanNode, GenericPlanRef, push_down_into_join, push_down_join_condition,
28};
29use super::utils::{Distill, childless_record};
30use super::{
31 ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeBinary, PredicatePushdown,
32 StreamHashJoin, StreamProject, ToBatch, ToStream, generic,
33};
34use crate::error::{ErrorCode, Result, RwError};
35use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::DynamicFilter;
38use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin;
39use crate::optimizer::plan_node::utils::IndicesDisplay;
40use crate::optimizer::plan_node::{
41 BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate,
42 LogicalFilter, LogicalScan, PredicatePushdownContext, RewriteStreamContext,
43 StreamDynamicFilter, StreamFilter, StreamTableScan, StreamTemporalJoin, ToStreamContext,
44};
45use crate::optimizer::plan_visitor::LogicalCardinalityExt;
46use crate::optimizer::property::{Distribution, Order, RequiredDist};
47use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay};
48
/// Logical plan node for a join between two input plans.
///
/// All join-specific state (the two inputs, the `on` condition, the join type and the output
/// indices) lives in `core`; `base` is derived from `core` when the node is built
/// (see `with_core`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalJoin {
    /// Plan base constructed from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// Generic join description; the same type is handed to the batch/stream join nodes.
    core: generic::Join<PlanRef>,
}
60
61impl Distill for LogicalJoin {
62 fn distill<'a>(&self) -> XmlNode<'a> {
63 let verbose = self.base.ctx().is_explain_verbose();
64 let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
65 vec.push(("type", Pretty::debug(&self.join_type())));
66
67 let concat_schema = self.core.concat_schema();
68 let cond = Pretty::debug(&ConditionDisplay {
69 condition: self.on(),
70 input_schema: &concat_schema,
71 });
72 vec.push(("on", cond));
73
74 if verbose {
75 let data = IndicesDisplay::from_join(&self.core, &concat_schema);
76 vec.push(("output", data));
77 }
78
79 childless_record("LogicalJoin", vec)
80 }
81}
82
83impl LogicalJoin {
84 pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
85 let core = generic::Join::with_full_output(left, right, join_type, on);
86 Self::with_core(core)
87 }
88
89 pub(crate) fn with_output_indices(
90 left: PlanRef,
91 right: PlanRef,
92 join_type: JoinType,
93 on: Condition,
94 output_indices: Vec<usize>,
95 ) -> Self {
96 let core = generic::Join::new(left, right, on, join_type, output_indices);
97 Self::with_core(core)
98 }
99
100 pub fn with_core(core: generic::Join<PlanRef>) -> Self {
101 let base = PlanBase::new_logical_with_core(&core);
102 LogicalJoin { base, core }
103 }
104
105 pub fn create(
106 left: PlanRef,
107 right: PlanRef,
108 join_type: JoinType,
109 on_clause: ExprImpl,
110 ) -> PlanRef {
111 Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
112 }
113
114 pub fn internal_column_num(&self) -> usize {
115 self.core.internal_column_num()
116 }
117
118 pub fn i2l_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
119 self.core.i2l_col_mapping_ignore_join_type()
120 }
121
122 pub fn i2r_col_mapping_ignore_join_type(&self) -> ColIndexMapping {
123 self.core.i2r_col_mapping_ignore_join_type()
124 }
125
126 pub fn on(&self) -> &Condition {
128 &self.core.on
129 }
130
131 pub fn input_idx_on_condition(&self) -> (Vec<usize>, Vec<usize>) {
133 let input_refs = self
134 .core
135 .on
136 .collect_input_refs(self.core.left.schema().len() + self.core.right.schema().len());
137 let index_group = input_refs
138 .ones()
139 .chunk_by(|i| *i < self.core.left.schema().len());
140 let left_index = index_group
141 .into_iter()
142 .next()
143 .map_or(vec![], |group| group.1.collect_vec());
144 let right_index = index_group.into_iter().next().map_or(vec![], |group| {
145 group
146 .1
147 .map(|i| i - self.core.left.schema().len())
148 .collect_vec()
149 });
150 (left_index, right_index)
151 }
152
153 pub fn join_type(&self) -> JoinType {
155 self.core.join_type
156 }
157
158 pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
160 self.core.eq_indexes()
161 }
162
163 pub fn output_indices(&self) -> &Vec<usize> {
165 &self.core.output_indices
166 }
167
168 pub fn clone_with_output_indices(&self, output_indices: Vec<usize>) -> Self {
170 Self::with_core(generic::Join {
171 output_indices,
172 ..self.core.clone()
173 })
174 }
175
176 pub fn clone_with_cond(&self, on: Condition) -> Self {
178 Self::with_core(generic::Join {
179 on,
180 ..self.core.clone()
181 })
182 }
183
184 pub fn is_left_join(&self) -> bool {
185 matches!(self.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
186 }
187
188 pub fn is_right_join(&self) -> bool {
189 matches!(self.join_type(), JoinType::RightSemi | JoinType::RightAnti)
190 }
191
192 pub fn is_full_out(&self) -> bool {
193 self.core.is_full_out()
194 }
195
196 pub fn is_asof_join(&self) -> bool {
197 self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter
198 }
199
200 pub fn output_indices_are_trivial(&self) -> bool {
201 self.output_indices() == &(0..self.internal_column_num()).collect_vec()
202 }
203
204 fn simplify_outer(predicate: &Condition, left_col_num: usize, join_type: JoinType) -> JoinType {
209 let (mut gen_null_in_left, mut gen_null_in_right) = match join_type {
210 JoinType::LeftOuter => (false, true),
211 JoinType::RightOuter => (true, false),
212 JoinType::FullOuter => (true, true),
213 _ => return join_type,
214 };
215
216 for expr in &predicate.conjunctions {
217 if let ExprImpl::FunctionCall(func) = expr {
218 match func.func_type() {
219 ExprType::Equal
220 | ExprType::NotEqual
221 | ExprType::LessThan
222 | ExprType::LessThanOrEqual
223 | ExprType::GreaterThan
224 | ExprType::GreaterThanOrEqual => {
225 for input in func.inputs() {
226 if let ExprImpl::InputRef(input) = input {
227 let idx = input.index;
228 if idx < left_col_num {
229 gen_null_in_left = false;
230 } else {
231 gen_null_in_right = false;
232 }
233 }
234 }
235 }
236 _ => {}
237 };
238 }
239 }
240
241 match (gen_null_in_left, gen_null_in_right) {
242 (true, true) => JoinType::FullOuter,
243 (true, false) => JoinType::RightOuter,
244 (false, true) => JoinType::LeftOuter,
245 (false, false) => JoinType::Inner,
246 }
247 }
248
249 fn to_batch_lookup_join_with_index_selection(
253 &self,
254 predicate: EqJoinPredicate,
255 logical_join: generic::Join<PlanRef>,
256 ) -> Result<Option<BatchLookupJoin>> {
257 match logical_join.join_type {
258 JoinType::Inner
259 | JoinType::LeftOuter
260 | JoinType::LeftSemi
261 | JoinType::LeftAnti
262 | JoinType::AsofInner
263 | JoinType::AsofLeftOuter => {}
264 _ => return Ok(None),
265 };
266
267 let right = self.right();
269 let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
271 logical_scan
272 } else {
273 return Ok(None);
274 };
275
276 let mut result_plan = None;
277 if let Some(lookup_join) =
279 self.to_batch_lookup_join(predicate.clone(), logical_join.clone())?
280 {
281 result_plan = Some(lookup_join);
282 }
283
284 let indexes = logical_scan.indexes();
285 for index in indexes {
286 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
287 let index_scan: PlanRef = index_scan.into();
288 let that = self.clone_with_left_right(self.left(), index_scan.clone());
289 let mut new_logical_join = logical_join.clone();
290 new_logical_join.right = index_scan.to_batch().expect("index scan failed to batch");
291
292 if let Some(lookup_join) =
294 that.to_batch_lookup_join(predicate.clone(), new_logical_join)?
295 {
296 match &result_plan {
297 None => result_plan = Some(lookup_join),
298 Some(prev_lookup_join) => {
299 if prev_lookup_join.lookup_prefix_len()
301 < lookup_join.lookup_prefix_len()
302 {
303 result_plan = Some(lookup_join)
304 }
305 }
306 }
307 }
308 }
309 }
310
311 Ok(result_plan)
312 }
313
/// Tries to plan this join as a `BatchLookupJoin` that looks rows up in the table scanned on
/// the right side of `self`.
///
/// `predicate` is the equi-join predicate to use and `logical_join` is the join core to build
/// from: its `left`, `join_type` and `output_indices` are reused, while its right input is
/// replaced by the scan obtained from `predicate_pull_up`.
///
/// Returns `Ok(None)` when a lookup join is not applicable:
/// - the join type is not inner, left outer, left semi, left anti or an AS OF variant;
/// - the right input of `self` is not a `LogicalScan`;
/// - the table's distribution key is empty;
/// - the equal conditions do not cover a long enough prefix of the table's order key;
/// - no equal condition remains once the scan's predicate is merged into the join condition.
///
/// # Errors
/// For AS OF joins, propagates the error of `get_inequality_desc_from_predicate`.
///
/// # Panics
/// Panics if a distribution key column is not found in the order key, or if an expression of
/// the pulled-up projection is not an input ref (`as_input_ref()` is unwrapped).
fn to_batch_lookup_join(
    &self,
    predicate: EqJoinPredicate,
    logical_join: generic::Join<PlanRef>,
) -> Result<Option<BatchLookupJoin>> {
    match logical_join.join_type {
        JoinType::Inner
        | JoinType::LeftOuter
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::AsofInner
        | JoinType::AsofLeftOuter => {}
        _ => return Ok(None),
    };

    // The lookup side must be a plain logical scan.
    let right = self.right();
    let logical_scan: &LogicalScan = if let Some(logical_scan) = right.as_logical_scan() {
        logical_scan
    } else {
        return Ok(None);
    };
    let table_desc = logical_scan.table_desc().clone();
    let output_column_ids = logical_scan.output_column_ids();

    let order_col_ids = table_desc.order_column_ids();
    let order_key = table_desc.order_column_indices();
    let dist_key = table_desc.distribution_key.clone();
    // Position of every distribution key column inside the order key.
    let mut dist_key_in_order_key_pos = vec![];
    for d in dist_key {
        let pos = order_key
            .iter()
            .position(|&x| x == d)
            .expect("dist_key must in order_key");
        dist_key_in_order_key_pos.push(pos);
    }
    // Length of the shortest order-key prefix that contains the whole distribution key
    // (0 when the distribution key is empty).
    let shortest_prefix_len = dist_key_in_order_key_pos
        .iter()
        .max()
        .map_or(0, |pos| pos + 1);

    // An empty distribution key is not supported by this lookup join.
    if shortest_prefix_len == 0 {
        return Ok(None);
    }

    // Walk the order key and, for each column, find the equal condition whose right side is
    // that column. Stop at the first order-key column that has no matching equal condition,
    // so `reorder_idx` describes the longest covered order-key prefix.
    let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
    for order_col_id in order_col_ids {
        let mut found = false;
        for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
            if order_col_id == output_column_ids[eq_idx] {
                reorder_idx.push(i);
                found = true;
                break;
            }
        }
        if !found {
            break;
        }
    }
    // The covered prefix must at least contain the distribution key.
    if reorder_idx.len() < shortest_prefix_len {
        return Ok(None);
    }
    let lookup_prefix_len = reorder_idx.len();
    let predicate = predicate.reorder(&reorder_idx);

    // Pull the scan's own predicate (and projection, if any) up into the join.
    let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
    // `o2r[i]` is the column of the new scan that the i-th output column of the old scan
    // refers to; identity when no projection was pulled up.
    let o2r = if let Some(project_expr) = project_expr {
        project_expr
            .into_iter()
            .map(|x| x.as_input_ref().unwrap().index)
            .collect_vec()
    } else {
        (0..logical_scan.output_col_idx().len()).collect_vec()
    };
    let left_schema_len = logical_join.left.schema().len();

    // Rewrites right-side input refs of the join condition through `o2r`.
    let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
        offset: left_schema_len,
        mapping: o2r.clone(),
    };

    let new_eq_cond = predicate
        .eq_cond()
        .rewrite_expr(&mut join_predicate_rewriter);

    // Shifts the scan predicate's input refs by the left schema length so they address the
    // right side of the join.
    let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
        offset: left_schema_len,
    };

    let new_other_cond = predicate
        .other_cond()
        .clone()
        .rewrite_expr(&mut join_predicate_rewriter)
        .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));

    let new_join_on = new_eq_cond.and(new_other_cond);
    let new_predicate = EqJoinPredicate::create(
        left_schema_len,
        new_scan.schema().len(),
        new_join_on.clone(),
    );

    // After merging the scan predicate there may be no equal condition left; a lookup join
    // cannot be used in that case.
    if !new_predicate.has_eq() {
        return Ok(None);
    }

    // Output indices that referred to the old scan must be remapped to the new scan.
    let new_join_output_indices = logical_join
        .output_indices
        .iter()
        .map(|&x| {
            if x < left_schema_len {
                x
            } else {
                o2r[x - left_schema_len] + left_schema_len
            }
        })
        .collect_vec();

    let new_scan_output_column_ids = new_scan.output_column_ids();
    let as_of = new_scan.as_of.clone();

    // Build a new join core because the right input has changed.
    let new_logical_join = generic::Join::new(
        logical_join.left,
        new_scan.into(),
        new_join_on,
        logical_join.join_type,
        new_join_output_indices,
    );

    // Only AS OF joins carry an inequality description; it is derived from the other
    // conditions of the reordered (not the rewritten) predicate.
    let asof_desc = self
        .is_asof_join()
        .then(|| {
            Self::get_inequality_desc_from_predicate(
                predicate.other_cond().clone(),
                left_schema_len,
            )
        })
        .transpose()?;

    // NOTE(review): the meaning of the literal `false` argument is defined by
    // `BatchLookupJoin::new`, which is not visible here.
    Ok(Some(BatchLookupJoin::new(
        new_logical_join,
        new_predicate,
        table_desc,
        new_scan_output_column_ids,
        lookup_prefix_len,
        false,
        as_of,
        asof_desc,
    )))
}
478
479 pub fn decompose(self) -> (PlanRef, PlanRef, Condition, JoinType, Vec<usize>) {
480 self.core.decompose()
481 }
482}
483
484impl PlanTreeNodeBinary for LogicalJoin {
485 fn left(&self) -> PlanRef {
486 self.core.left.clone()
487 }
488
489 fn right(&self) -> PlanRef {
490 self.core.right.clone()
491 }
492
493 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
494 Self::with_core(generic::Join {
495 left,
496 right,
497 ..self.core.clone()
498 })
499 }
500
501 fn rewrite_with_left_right(
502 &self,
503 left: PlanRef,
504 left_col_change: ColIndexMapping,
505 right: PlanRef,
506 right_col_change: ColIndexMapping,
507 ) -> (Self, ColIndexMapping) {
508 let (new_on, new_output_indices) = {
509 let (mut map, _) = left_col_change.clone().into_parts();
510 let (mut right_map, _) = right_col_change.clone().into_parts();
511 for i in right_map.iter_mut().flatten() {
512 *i += left.schema().len();
513 }
514 map.append(&mut right_map);
515 let mut mapping = ColIndexMapping::new(map, left.schema().len() + right.schema().len());
516
517 let new_output_indices = self
518 .output_indices()
519 .iter()
520 .map(|&i| mapping.map(i))
521 .collect::<Vec<_>>();
522 let new_on = self.on().clone().rewrite_expr(&mut mapping);
523 (new_on, new_output_indices)
524 };
525
526 let join = Self::with_output_indices(
527 left,
528 right,
529 self.join_type(),
530 new_on,
531 new_output_indices.clone(),
532 );
533
534 let new_i2o = ColIndexMapping::with_remaining_columns(
535 &new_output_indices,
536 join.internal_column_num(),
537 );
538
539 let old_o2i = self.core.o2i_col_mapping();
540
541 let old_o2l = old_o2i
542 .composite(&self.core.i2l_col_mapping())
543 .composite(&left_col_change);
544 let old_o2r = old_o2i
545 .composite(&self.core.i2r_col_mapping())
546 .composite(&right_col_change);
547 let new_l2o = join.core.l2i_col_mapping().composite(&new_i2o);
548 let new_r2o = join.core.r2i_col_mapping().composite(&new_i2o);
549
550 let out_col_change = old_o2l
551 .composite(&new_l2o)
552 .union(&old_o2r.composite(&new_r2o));
553 (join, out_col_change)
554 }
555}
556
557impl_plan_tree_node_for_binary! { LogicalJoin }
558
559impl ColPrunable for LogicalJoin {
560 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
561 let required_cols = required_cols
563 .iter()
564 .map(|i| self.output_indices()[*i])
565 .collect_vec();
566 let left_len = self.left().schema().fields.len();
567
568 let total_len = self.left().schema().len() + self.right().schema().len();
569 let mut resized_required_cols = FixedBitSet::with_capacity(total_len);
570
571 required_cols.iter().for_each(|&i| {
572 if self.is_right_join() {
573 resized_required_cols.insert(left_len + i);
574 } else {
575 resized_required_cols.insert(i);
576 }
577 });
578
579 let mut visitor = CollectInputRef::new(resized_required_cols);
582 self.on().visit_expr(&mut visitor);
583 let left_right_required_cols = FixedBitSet::from(visitor).ones().collect_vec();
584
585 let mut left_required_cols = Vec::new();
586 let mut right_required_cols = Vec::new();
587 left_right_required_cols.iter().for_each(|&i| {
588 if i < left_len {
589 left_required_cols.push(i);
590 } else {
591 right_required_cols.push(i - left_len);
592 }
593 });
594
595 let mut on = self.on().clone();
596 let mut mapping =
597 ColIndexMapping::with_remaining_columns(&left_right_required_cols, total_len);
598 on = on.rewrite_expr(&mut mapping);
599
600 let new_output_indices = {
601 let required_inputs_in_output = if self.is_left_join() {
602 &left_required_cols
603 } else if self.is_right_join() {
604 &right_required_cols
605 } else {
606 &left_right_required_cols
607 };
608
609 let mapping =
610 ColIndexMapping::with_remaining_columns(required_inputs_in_output, total_len);
611 required_cols.iter().map(|&i| mapping.map(i)).collect_vec()
612 };
613
614 LogicalJoin::with_output_indices(
615 self.left().prune_col(&left_required_cols, ctx),
616 self.right().prune_col(&right_required_cols, ctx),
617 self.join_type(),
618 on,
619 new_output_indices,
620 )
621 .into()
622 }
623}
624
625impl ExprRewritable for LogicalJoin {
626 fn has_rewritable_expr(&self) -> bool {
627 true
628 }
629
630 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
631 let mut core = self.core.clone();
632 core.rewrite_exprs(r);
633 Self {
634 base: self.base.clone_with_new_plan_id(),
635 core,
636 }
637 .into()
638 }
639}
640
641impl ExprVisitable for LogicalJoin {
642 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
643 self.core.visit_exprs(v);
644 }
645}
646
647fn derive_predicate_from_eq_condition(
665 expr: &ExprImpl,
666 eq_condition: &EqJoinPredicate,
667 col_num: usize,
668 expr_is_left: bool,
669) -> Option<ExprImpl> {
670 if expr.is_impure() {
671 return None;
672 }
673 let eq_indices = eq_condition
674 .eq_indexes_typed()
675 .iter()
676 .filter_map(|(l, r)| {
677 if l.return_type() != r.return_type() {
678 None
679 } else if expr_is_left {
680 Some(l.index())
681 } else {
682 Some(r.index())
683 }
684 })
685 .collect_vec();
686 if expr
687 .collect_input_refs(col_num)
688 .ones()
689 .any(|index| !eq_indices.contains(&index))
690 {
691 return None;
693 }
694 let other_side_mapping = if expr_is_left {
697 eq_condition.eq_indexes_typed().into_iter().collect()
698 } else {
699 eq_condition
700 .eq_indexes_typed()
701 .into_iter()
702 .map(|(x, y)| (y, x))
703 .collect()
704 };
705 struct InputRefsRewriter {
706 mapping: HashMap<InputRef, InputRef>,
707 }
708 impl ExprRewriter for InputRefsRewriter {
709 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
710 self.mapping[&input_ref].clone().into()
711 }
712 }
713 Some(
714 InputRefsRewriter {
715 mapping: other_side_mapping,
716 }
717 .rewrite_expr(expr.clone()),
718 )
719}
720
721struct LookupJoinPredicateRewriter {
723 offset: usize,
724 mapping: Vec<usize>,
725}
726impl ExprRewriter for LookupJoinPredicateRewriter {
727 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
728 if input_ref.index() < self.offset {
729 input_ref.into()
730 } else {
731 InputRef::new(
732 self.mapping[input_ref.index() - self.offset] + self.offset,
733 input_ref.return_type(),
734 )
735 .into()
736 }
737 }
738}
739
740struct LookupJoinScanPredicateRewriter {
742 offset: usize,
743}
744impl ExprRewriter for LookupJoinScanPredicateRewriter {
745 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
746 InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
747 }
748}
749
750impl PredicatePushdown for LogicalJoin {
751 fn predicate_pushdown(
775 &self,
776 predicate: Condition,
777 ctx: &mut PredicatePushdownContext,
778 ) -> PlanRef {
779 let mut predicate = {
781 let mut mapping = self.core.o2i_col_mapping();
782 predicate.rewrite_expr(&mut mapping)
783 };
784
785 let left_col_num = self.left().schema().len();
786 let right_col_num = self.right().schema().len();
787 let join_type = LogicalJoin::simplify_outer(&predicate, left_col_num, self.join_type());
788
789 let push_down_temporal_predicate = !self.should_be_temporal_join();
790
791 let (left_from_filter, right_from_filter, on) = push_down_into_join(
792 &mut predicate,
793 left_col_num,
794 right_col_num,
795 join_type,
796 push_down_temporal_predicate,
797 );
798
799 let mut new_on = self.on().clone().and(on);
800 let (left_from_on, right_from_on) = push_down_join_condition(
801 &mut new_on,
802 left_col_num,
803 right_col_num,
804 join_type,
805 push_down_temporal_predicate,
806 );
807
808 let left_predicate = left_from_filter.and(left_from_on);
809 let right_predicate = right_from_filter.and(right_from_on);
810
811 let eq_condition = EqJoinPredicate::create(left_col_num, right_col_num, new_on.clone());
813
814 let right_from_left = if matches!(
816 join_type,
817 JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi | JoinType::LeftSemi
818 ) {
819 Condition {
820 conjunctions: left_predicate
821 .conjunctions
822 .iter()
823 .filter_map(|expr| {
824 derive_predicate_from_eq_condition(expr, &eq_condition, left_col_num, true)
825 })
826 .collect(),
827 }
828 } else {
829 Condition::true_cond()
830 };
831
832 let left_from_right = if matches!(
834 join_type,
835 JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi | JoinType::RightSemi
836 ) {
837 Condition {
838 conjunctions: right_predicate
839 .conjunctions
840 .iter()
841 .filter_map(|expr| {
842 derive_predicate_from_eq_condition(
843 expr,
844 &eq_condition,
845 right_col_num,
846 false,
847 )
848 })
849 .collect(),
850 }
851 } else {
852 Condition::true_cond()
853 };
854
855 let left_predicate = left_predicate.and(left_from_right);
856 let right_predicate = right_predicate.and(right_from_left);
857
858 let new_left = self.left().predicate_pushdown(left_predicate, ctx);
859 let new_right = self.right().predicate_pushdown(right_predicate, ctx);
860 let new_join = LogicalJoin::with_output_indices(
861 new_left,
862 new_right,
863 join_type,
864 new_on,
865 self.output_indices().clone(),
866 );
867
868 let mut mapping = self.core.i2o_col_mapping();
869 predicate = predicate.rewrite_expr(&mut mapping);
870 LogicalFilter::create(new_join.into(), predicate)
871 }
872}
873
874impl LogicalJoin {
/// Converts both inputs to stream plans whose distributions fit a hash join on the equal
/// keys of `predicate`, returning `(left, right)`.
///
/// The right input is converted first, required to be sharded by the right equal keys.
/// Depending on the distribution it ends up with:
/// - `HashShard`: the left input is required to have the right distribution translated
///   through the right-to-left equal-column mapping.
/// - `UpstreamHashShard`: the left input is required to be sharded by the left equal keys.
///   If it becomes `HashShard`, the right input is enforced (if not already satisfying) to
///   the left distribution translated through the left-to-right mapping; if it is also
///   `UpstreamHashShard`, both sides are enforced to hash-shard on their equal keys.
///
/// # Errors
/// Propagates errors from converting the inputs or enforcing a distribution.
///
/// # Panics
/// Hits `unreachable!()` if a distribution is neither `HashShard` nor `UpstreamHashShard`.
fn get_stream_input_for_hash_join(
    &self,
    predicate: &EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<(PlanRef, PlanRef)> {
    use super::stream::prelude::*;

    let mut right = self.right().to_stream_with_dist_required(
        &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()),
        ctx,
    )?;
    // Still the logical left input here; it is converted in the branches below.
    let mut left = self.left();

    let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
    let l2r = predicate.l2r_eq_columns_mapping(left.schema().len(), right.schema().len());

    let right_dist = right.distribution();
    match right_dist {
        Distribution::HashShard(_) => {
            // Make the left side follow the right side's physical distribution.
            let left_dist = r2l
                .rewrite_required_distribution(&RequiredDist::PhysicalDist(right_dist.clone()));
            left = left.to_stream_with_dist_required(&left_dist, ctx)?;
        }
        Distribution::UpstreamHashShard(_, _) => {
            left = left.to_stream_with_dist_required(
                &RequiredDist::shard_by_key(
                    self.left().schema().len(),
                    &predicate.left_eq_indexes(),
                ),
                ctx,
            )?;
            let left_dist = left.distribution();
            match left_dist {
                Distribution::HashShard(_) => {
                    // Make the right side follow the left side's physical distribution.
                    let right_dist = l2r.rewrite_required_distribution(
                        &RequiredDist::PhysicalDist(left_dist.clone()),
                    );
                    right = right_dist.enforce_if_not_satisfies(right, &Order::any())?
                }
                Distribution::UpstreamHashShard(_, _) => {
                    // Both sides are upstream-sharded: enforce a hash shard on the equal
                    // keys of each side.
                    left = RequiredDist::hash_shard(&predicate.left_eq_indexes())
                        .enforce_if_not_satisfies(left, &Order::any())?;
                    right = RequiredDist::hash_shard(&predicate.right_eq_indexes())
                        .enforce_if_not_satisfies(right, &Order::any())?;
                }
                _ => unreachable!(),
            }
        }
        _ => unreachable!(),
    }
    Ok((left, right))
}
927
928 fn to_stream_hash_join(
929 &self,
930 predicate: EqJoinPredicate,
931 ctx: &mut ToStreamContext,
932 ) -> Result<PlanRef> {
933 use super::stream::prelude::*;
934
935 assert!(predicate.has_eq());
936 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
937
938 let logical_join = self.clone_with_left_right(left, right);
939
940 let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
949
950 let force_filter_inside_join = self
951 .base
952 .ctx()
953 .session_ctx()
954 .config()
955 .streaming_force_filter_inside_join();
956
957 let pull_filter = self.join_type() == JoinType::Inner
958 && stream_hash_join.eq_join_predicate().has_non_eq()
959 && stream_hash_join.inequality_pairs().is_empty()
960 && (!force_filter_inside_join);
961 if pull_filter {
962 let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
963
964 let logical_join = logical_join.clone_with_output_indices(default_indices.clone());
966 let eq_cond = EqJoinPredicate::new(
967 Condition::true_cond(),
968 predicate.eq_keys().to_vec(),
969 self.left().schema().len(),
970 self.right().schema().len(),
971 );
972 let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
973 let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
974 let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
975 let plan = StreamFilter::new(logical_filter).into();
976 if self.output_indices() != &default_indices {
977 let logical_project = generic::Project::with_mapping(
978 plan,
979 ColIndexMapping::with_remaining_columns(
980 self.output_indices(),
981 self.internal_column_num(),
982 ),
983 );
984 Ok(StreamProject::new(logical_project).into())
985 } else {
986 Ok(plan)
987 }
988 } else {
989 Ok(stream_hash_join.into())
990 }
991 }
992
993 fn should_be_temporal_join(&self) -> bool {
994 let right = self.right();
995 if let Some(logical_scan) = right.as_logical_scan() {
996 matches!(logical_scan.as_of(), Some(AsOf::ProcessTime))
997 } else {
998 false
999 }
1000 }
1001
1002 fn to_stream_temporal_join_with_index_selection(
1003 &self,
1004 predicate: EqJoinPredicate,
1005 ctx: &mut ToStreamContext,
1006 ) -> Result<PlanRef> {
1007 let right = self.right();
1009 let logical_scan: &LogicalScan = right.as_logical_scan().unwrap();
1011
1012 let mut result_plan: Result<StreamTemporalJoin> =
1014 self.to_stream_temporal_join(predicate.clone(), ctx);
1015 if let Ok(temporal_join) = &result_plan
1017 && temporal_join.eq_join_predicate().eq_indexes().len()
1018 == logical_scan.primary_key().len()
1019 {
1020 return result_plan.map(|x| x.into());
1021 }
1022 let indexes = logical_scan.indexes();
1023 for index in indexes {
1024 if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
1026 let index_scan: PlanRef = index_scan.into();
1027 let that = self.clone_with_left_right(self.left(), index_scan.clone());
1028 if let Ok(temporal_join) = that.to_stream_temporal_join(predicate.clone(), ctx) {
1029 match &result_plan {
1030 Err(_) => result_plan = Ok(temporal_join),
1031 Ok(prev_temporal_join) => {
1032 if prev_temporal_join.eq_join_predicate().eq_indexes().len()
1034 < temporal_join.eq_join_predicate().eq_indexes().len()
1035 {
1036 result_plan = Ok(temporal_join)
1037 }
1038 }
1039 }
1040 }
1041 }
1042 }
1043
1044 result_plan.map(|x| x.into())
1045 }
1046
1047 fn check_temporal_rhs(right: &PlanRef) -> Result<&LogicalScan> {
1048 let Some(logical_scan) = right.as_logical_scan() else {
1049 return Err(RwError::from(ErrorCode::NotSupported(
1050 "Temporal join requires a table scan as its lookup table".into(),
1051 "Please provide a table scan".into(),
1052 )));
1053 };
1054
1055 if !matches!(logical_scan.as_of(), Some(AsOf::ProcessTime)) {
1056 return Err(RwError::from(ErrorCode::NotSupported(
1057 "Temporal join requires a table defined as temporal table".into(),
1058 "Please use FOR SYSTEM_TIME AS OF PROCTIME() syntax".into(),
1059 )));
1060 }
1061 Ok(logical_scan)
1062 }
1063
1064 fn temporal_join_scan_predicate_pull_up(
1065 logical_scan: &LogicalScan,
1066 predicate: EqJoinPredicate,
1067 output_indices: &[usize],
1068 left_schema_len: usize,
1069 ) -> Result<(StreamTableScan, EqJoinPredicate, Condition, Vec<usize>)> {
1070 let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up();
1072 let o2r = if let Some(project_expr) = project_expr {
1074 project_expr
1075 .into_iter()
1076 .map(|x| x.as_input_ref().unwrap().index)
1077 .collect_vec()
1078 } else {
1079 (0..logical_scan.output_col_idx().len()).collect_vec()
1080 };
1081 let mut join_predicate_rewriter = LookupJoinPredicateRewriter {
1082 offset: left_schema_len,
1083 mapping: o2r.clone(),
1084 };
1085
1086 let new_eq_cond = predicate
1087 .eq_cond()
1088 .rewrite_expr(&mut join_predicate_rewriter);
1089
1090 let mut scan_predicate_rewriter = LookupJoinScanPredicateRewriter {
1091 offset: left_schema_len,
1092 };
1093
1094 let new_other_cond = predicate
1095 .other_cond()
1096 .clone()
1097 .rewrite_expr(&mut join_predicate_rewriter)
1098 .and(scan_predicate.rewrite_expr(&mut scan_predicate_rewriter));
1099
1100 let new_join_on = new_eq_cond.and(new_other_cond);
1101
1102 let new_predicate = EqJoinPredicate::create(
1103 left_schema_len,
1104 new_scan.schema().len(),
1105 new_join_on.clone(),
1106 );
1107
1108 let new_join_output_indices = output_indices
1111 .iter()
1112 .map(|&x| {
1113 if x < left_schema_len {
1114 x
1115 } else {
1116 o2r[x - left_schema_len] + left_schema_len
1117 }
1118 })
1119 .collect_vec();
1120 let new_stream_table_scan =
1122 StreamTableScan::new_with_stream_scan_type(new_scan, StreamScanType::UpstreamOnly);
1123 Ok((
1124 new_stream_table_scan,
1125 new_predicate,
1126 new_join_on,
1127 new_join_output_indices,
1128 ))
1129 }
1130
/// Plans this join as a `StreamTemporalJoin` that looks rows up in the table scanned on the
/// right side.
///
/// The equal conditions of `predicate` are reordered to follow the lookup table's order key;
/// the left input is converted to a stream plan and enforced to the distribution derived from
/// the table's distribution key; the scan's predicate is pulled up into the join condition;
/// and the right input becomes a no-shuffle exchange over the new stream table scan.
///
/// # Errors
/// Returns `ErrorCode::NotSupported` if the right input is not a temporal table scan, if the
/// equal conditions do not cover a long enough prefix of the order key, or if no equal
/// condition remains after the scan predicate is merged in. Also propagates errors from
/// converting the left input.
///
/// # Panics
/// Panics if `predicate` has no equal condition, or if a distribution key column is not found
/// in the order key.
fn to_stream_temporal_join(
    &self,
    predicate: EqJoinPredicate,
    ctx: &mut ToStreamContext,
) -> Result<StreamTemporalJoin> {
    use super::stream::prelude::*;

    assert!(predicate.has_eq());

    let right = self.right();

    let logical_scan = Self::check_temporal_rhs(&right)?;

    let table_desc = logical_scan.table_desc();
    let output_column_ids = logical_scan.output_column_ids();

    let order_col_ids = table_desc.order_column_ids();
    let order_key = table_desc.order_column_indices();
    let dist_key = table_desc.distribution_key.clone();

    // Position of every distribution key column inside the order key.
    let mut dist_key_in_order_key_pos = vec![];
    for d in dist_key {
        let pos = order_key
            .iter()
            .position(|&x| x == d)
            .expect("dist_key must in order_key");
        dist_key_in_order_key_pos.push(pos);
    }
    // Length of the shortest order-key prefix that contains the whole distribution key
    // (0 when the distribution key is empty).
    let shortest_prefix_len = dist_key_in_order_key_pos
        .iter()
        .max()
        .map_or(0, |pos| pos + 1);

    // Walk the order key and, for each column, find the equal condition whose right side is
    // that column. Stop at the first order-key column without a matching equal condition.
    let mut reorder_idx = Vec::with_capacity(shortest_prefix_len);
    for order_col_id in order_col_ids {
        let mut found = false;
        for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() {
            if order_col_id == output_column_ids[eq_idx] {
                reorder_idx.push(i);
                found = true;
                break;
            }
        }
        if !found {
            break;
        }
    }
    // The covered prefix must at least contain the distribution key.
    if reorder_idx.len() < shortest_prefix_len {
        return Err(RwError::from(ErrorCode::NotSupported(
            "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(),
            "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(),
        )));
    }
    let lookup_prefix_len = reorder_idx.len();
    let predicate = predicate.reorder(&reorder_idx);

    // With an empty distribution key the left side must be a singleton; otherwise it is
    // hash-sharded on the left columns paired with the table's distribution key columns.
    let required_dist = if dist_key_in_order_key_pos.is_empty() {
        RequiredDist::single()
    } else {
        let left_eq_indexes = predicate.left_eq_indexes();
        let left_dist_key = dist_key_in_order_key_pos
            .iter()
            .map(|pos| left_eq_indexes[*pos])
            .collect_vec();

        RequiredDist::hash_shard(&left_dist_key)
    };

    let left = self.left().to_stream(ctx)?;
    // The distribution is enforced unconditionally (`enforce`, not
    // `enforce_if_not_satisfies`), presumably so the left side always gets an exchange.
    let left = required_dist.enforce(left, &Order::any());

    let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
        Self::temporal_join_scan_predicate_pull_up(
            logical_scan,
            predicate,
            self.output_indices(),
            self.left().schema().len(),
        )?;

    let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
    // After merging the scan predicate there may be no equal condition left.
    if !new_predicate.has_eq() {
        return Err(RwError::from(ErrorCode::NotSupported(
            "Temporal join requires a non trivial join condition".into(),
            "Please remove the false condition of the join".into(),
        )));
    }

    // Build a new join core because both inputs have changed.
    let new_logical_join = generic::Join::new(
        left,
        right,
        new_join_on,
        self.join_type(),
        new_join_output_indices,
    );

    // Keep only the equal keys that belong to the covered order-key prefix.
    let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len);

    // NOTE(review): the nested-loop variant in this file passes `true` for the last
    // argument, so it presumably selects nested-loop mode — confirm against
    // `StreamTemporalJoin::new`.
    Ok(StreamTemporalJoin::new(
        new_logical_join,
        new_predicate,
        false,
    ))
}
1241
1242 fn to_stream_nested_loop_temporal_join(
1243 &self,
1244 predicate: EqJoinPredicate,
1245 ctx: &mut ToStreamContext,
1246 ) -> Result<PlanRef> {
1247 use super::stream::prelude::*;
1248 assert!(!predicate.has_eq());
1249
1250 let left = self.left().to_stream_with_dist_required(
1251 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1252 ctx,
1253 )?;
1254 assert!(left.as_stream_exchange().is_some());
1255
1256 if self.join_type() != JoinType::Inner {
1257 return Err(RwError::from(ErrorCode::NotSupported(
1258 "Temporal join requires an inner join".into(),
1259 "Please use an inner join".into(),
1260 )));
1261 }
1262
1263 if !left.append_only() {
1264 return Err(RwError::from(ErrorCode::NotSupported(
1265 "Nested-loop Temporal join requires the left hash side to be append only".into(),
1266 "Please ensure the left hash side is append only".into(),
1267 )));
1268 }
1269
1270 let right = self.right();
1271 let logical_scan = Self::check_temporal_rhs(&right)?;
1272
1273 let (new_stream_table_scan, new_predicate, new_join_on, new_join_output_indices) =
1274 Self::temporal_join_scan_predicate_pull_up(
1275 logical_scan,
1276 predicate,
1277 self.output_indices(),
1278 self.left().schema().len(),
1279 )?;
1280
1281 let right = RequiredDist::no_shuffle(new_stream_table_scan.into());
1282
1283 let new_logical_join = generic::Join::new(
1285 left,
1286 right,
1287 new_join_on,
1288 self.join_type(),
1289 new_join_output_indices,
1290 );
1291
1292 Ok(StreamTemporalJoin::new(new_logical_join, new_predicate, true).into())
1293 }
1294
1295 fn to_stream_dynamic_filter(
1296 &self,
1297 predicate: Condition,
1298 ctx: &mut ToStreamContext,
1299 ) -> Result<Option<PlanRef>> {
1300 use super::stream::prelude::*;
1301
1302 if !matches!(self.join_type(), JoinType::Inner | JoinType::LeftSemi) {
1308 return Ok(None);
1309 }
1310
1311 if !self.right().max_one_row() {
1313 return Ok(None);
1314 }
1315 if self.right().schema().len() != 1 {
1316 return Ok(None);
1317 }
1318
1319 if predicate.conjunctions.len() > 1 {
1321 return Ok(None);
1322 }
1323 let expr: ExprImpl = predicate.into();
1324 let (left_ref, comparator, right_ref) = match expr.as_comparison_cond() {
1325 Some(v) => v,
1326 None => return Ok(None),
1327 };
1328
1329 let condition_cross_inputs = left_ref.index < self.left().schema().len()
1330 && right_ref.index == self.left().schema().len() ;
1331 if !condition_cross_inputs {
1332 return Ok(None);
1334 }
1335
1336 if self.left().schema().fields()[left_ref.index].data_type
1338 != self.right().schema().fields()[0].data_type
1339 {
1340 return Ok(None);
1341 }
1342
1343 let all_output_from_left = self
1345 .output_indices()
1346 .iter()
1347 .all(|i| *i < self.left().schema().len());
1348 if !all_output_from_left {
1349 return Ok(None);
1350 }
1351
1352 let left = self.left().to_stream(ctx)?;
1353 let right = self.right().to_stream_with_dist_required(
1354 &RequiredDist::PhysicalDist(Distribution::Broadcast),
1355 ctx,
1356 )?;
1357
1358 assert!(right.as_stream_exchange().is_some());
1359 assert_eq!(
1360 *right.inputs().iter().exactly_one().unwrap().distribution(),
1361 Distribution::Single
1362 );
1363
1364 let core = DynamicFilter::new(comparator, left_ref.index, left, right);
1365 let plan = StreamDynamicFilter::new(core).into();
1366 if self
1368 .output_indices()
1369 .iter()
1370 .copied()
1371 .ne(0..self.left().schema().len())
1372 {
1373 let logical_project = generic::Project::with_mapping(
1376 plan,
1377 ColIndexMapping::with_remaining_columns(
1378 self.output_indices(),
1379 self.left().schema().len(),
1380 ),
1381 );
1382 Ok(Some(StreamProject::new(logical_project).into()))
1383 } else {
1384 Ok(Some(plan))
1385 }
1386 }
1387
1388 pub fn index_lookup_join_to_batch_lookup_join(&self) -> Result<PlanRef> {
1389 let predicate = EqJoinPredicate::create(
1390 self.left().schema().len(),
1391 self.right().schema().len(),
1392 self.on().clone(),
1393 );
1394 assert!(predicate.has_eq());
1395
1396 let mut logical_join = self.core.clone();
1397 logical_join.left = logical_join.left.to_batch()?;
1398 logical_join.right = logical_join.right.to_batch()?;
1399
1400 Ok(self
1401 .to_batch_lookup_join(predicate, logical_join)?
1402 .expect("Fail to convert to lookup join")
1403 .into())
1404 }
1405
1406 fn to_stream_asof_join(
1407 &self,
1408 predicate: EqJoinPredicate,
1409 ctx: &mut ToStreamContext,
1410 ) -> Result<PlanRef> {
1411 use super::stream::prelude::*;
1412
1413 if predicate.eq_keys().is_empty() {
1414 return Err(ErrorCode::InvalidInputSyntax(
1415 "AsOf join requires at least 1 equal condition".to_owned(),
1416 )
1417 .into());
1418 }
1419
1420 let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?;
1421 let left_len = left.schema().len();
1422 let logical_join = self.clone_with_left_right(left, right);
1423
1424 let inequality_desc =
1425 Self::get_inequality_desc_from_predicate(predicate.other_cond().clone(), left_len)?;
1426
1427 Ok(StreamAsOfJoin::new(logical_join.core.clone(), predicate, inequality_desc).into())
1428 }
1429
1430 fn to_batch_hash_join(
1432 &self,
1433 logical_join: generic::Join<PlanRef>,
1434 predicate: EqJoinPredicate,
1435 ) -> Result<PlanRef> {
1436 use super::batch::prelude::*;
1437
1438 let left_schema_len = logical_join.left.schema().len();
1439 let asof_desc = self
1440 .is_asof_join()
1441 .then(|| {
1442 Self::get_inequality_desc_from_predicate(
1443 predicate.other_cond().clone(),
1444 left_schema_len,
1445 )
1446 })
1447 .transpose()?;
1448
1449 let batch_join = BatchHashJoin::new(logical_join, predicate, asof_desc);
1450 Ok(batch_join.into())
1451 }
1452
1453 pub fn get_inequality_desc_from_predicate(
1454 predicate: Condition,
1455 left_input_len: usize,
1456 ) -> Result<AsOfJoinDesc> {
1457 let expr: ExprImpl = predicate.into();
1458 if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() {
1459 if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len
1460 {
1461 Ok(AsOfJoinDesc {
1462 left_idx: left_input_ref.index() as u32,
1463 right_idx: (right_input_ref.index() - left_input_len) as u32,
1464 inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(),
1465 })
1466 } else {
1467 bail!("inequal condition from the same side should be push down in optimizer");
1468 }
1469 } else {
1470 Err(ErrorCode::InvalidInputSyntax(
1471 "AsOf join requires exactly 1 ineuquality condition".to_owned(),
1472 )
1473 .into())
1474 }
1475 }
1476
1477 fn expr_type_to_comparison_type(expr_type: PbType) -> Result<PbAsOfJoinInequalityType> {
1478 match expr_type {
1479 PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt),
1480 PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe),
1481 PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt),
1482 PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe),
1483 _ => Err(ErrorCode::InvalidInputSyntax(format!(
1484 "Invalid comparison type: {}",
1485 expr_type.as_str_name()
1486 ))
1487 .into()),
1488 }
1489 }
1490}
1491
1492impl ToBatch for LogicalJoin {
1493 fn to_batch(&self) -> Result<PlanRef> {
1494 let predicate = EqJoinPredicate::create(
1495 self.left().schema().len(),
1496 self.right().schema().len(),
1497 self.on().clone(),
1498 );
1499
1500 let mut logical_join = self.core.clone();
1501 logical_join.left = logical_join.left.to_batch()?;
1502 logical_join.right = logical_join.right.to_batch()?;
1503
1504 let ctx = self.base.ctx();
1505 let config = ctx.session_ctx().config();
1506
1507 if predicate.has_eq() {
1508 if !predicate.eq_keys_are_type_aligned() {
1509 return Err(ErrorCode::InternalError(format!(
1510 "Join eq keys are not aligned for predicate: {predicate:?}"
1511 ))
1512 .into());
1513 }
1514 if config.batch_enable_lookup_join() {
1515 if let Some(lookup_join) = self.to_batch_lookup_join_with_index_selection(
1516 predicate.clone(),
1517 logical_join.clone(),
1518 )? {
1519 return Ok(lookup_join.into());
1520 }
1521 }
1522 self.to_batch_hash_join(logical_join, predicate)
1523 } else if self.is_asof_join() {
1524 return Err(ErrorCode::InvalidInputSyntax(
1525 "AsOf join requires at least 1 equal condition".to_owned(),
1526 )
1527 .into());
1528 } else {
1529 Ok(BatchNestedLoopJoin::new(logical_join).into())
1531 }
1532 }
1533}
1534
1535impl ToStream for LogicalJoin {
1536 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1537 if self
1538 .on()
1539 .conjunctions
1540 .iter()
1541 .any(|cond| cond.count_nows() > 0)
1542 {
1543 return Err(ErrorCode::NotSupported(
1544 "optimizer has tried to separate the temporal predicate(with now() expression) from the on condition, but it still reminded in on join's condition. Considering move it into WHERE clause?".to_owned(),
1545 "please refer to https://www.risingwave.dev/docs/current/sql-pattern-temporal-filters/ for more information".to_owned()).into());
1546 }
1547
1548 let predicate = EqJoinPredicate::create(
1549 self.left().schema().len(),
1550 self.right().schema().len(),
1551 self.on().clone(),
1552 );
1553
1554 if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter {
1555 self.to_stream_asof_join(predicate, ctx)
1556 } else if predicate.has_eq() {
1557 if !predicate.eq_keys_are_type_aligned() {
1558 return Err(ErrorCode::InternalError(format!(
1559 "Join eq keys are not aligned for predicate: {predicate:?}"
1560 ))
1561 .into());
1562 }
1563
1564 if self.should_be_temporal_join() {
1565 self.to_stream_temporal_join_with_index_selection(predicate, ctx)
1566 } else {
1567 self.to_stream_hash_join(predicate, ctx)
1568 }
1569 } else if self.should_be_temporal_join() {
1570 self.to_stream_nested_loop_temporal_join(predicate, ctx)
1571 } else if let Some(dynamic_filter) =
1572 self.to_stream_dynamic_filter(self.on().clone(), ctx)?
1573 {
1574 Ok(dynamic_filter)
1575 } else {
1576 Err(RwError::from(ErrorCode::NotSupported(
1577 "streaming nested-loop join".to_owned(),
1578 "The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. \
1579 Consider rewriting the query to use dynamic filter as a substitute if possible.\n\
1580 See also: https://docs.risingwave.com/docs/current/sql-pattern-dynamic-filters/".to_owned(),
1581 )))
1582 }
1583 }
1584
1585 fn logical_rewrite_for_stream(
1586 &self,
1587 ctx: &mut RewriteStreamContext,
1588 ) -> Result<(PlanRef, ColIndexMapping)> {
1589 let (left, left_col_change) = self.left().logical_rewrite_for_stream(ctx)?;
1590 let left_len = left.schema().len();
1591 let (right, right_col_change) = self.right().logical_rewrite_for_stream(ctx)?;
1592 let (join, out_col_change) = self.rewrite_with_left_right(
1593 left.clone(),
1594 left_col_change,
1595 right.clone(),
1596 right_col_change,
1597 );
1598
1599 let mapping = ColIndexMapping::with_remaining_columns(
1600 join.output_indices(),
1601 join.internal_column_num(),
1602 );
1603
1604 let l2o = join.core.l2i_col_mapping().composite(&mapping);
1605 let r2o = join.core.r2i_col_mapping().composite(&mapping);
1606
1607 let mut left_to_add = left
1609 .expect_stream_key()
1610 .iter()
1611 .cloned()
1612 .filter(|i| l2o.try_map(*i).is_none())
1613 .collect_vec();
1614
1615 let mut right_to_add = right
1616 .expect_stream_key()
1617 .iter()
1618 .filter(|&&i| r2o.try_map(i).is_none())
1619 .map(|&i| i + left_len)
1620 .collect_vec();
1621
1622 let right_len = right.schema().len();
1625 let eq_predicate = EqJoinPredicate::create(left_len, right_len, join.on().clone());
1626
1627 let either_or_both = self.core.add_which_join_key_to_pk();
1628
1629 for (lk, rk) in eq_predicate.eq_indexes() {
1630 match either_or_both {
1631 EitherOrBoth::Left(_) => {
1632 if l2o.try_map(lk).is_none() {
1633 left_to_add.push(lk);
1634 }
1635 }
1636 EitherOrBoth::Right(_) => {
1637 if r2o.try_map(rk).is_none() {
1638 right_to_add.push(rk + left_len)
1639 }
1640 }
1641 EitherOrBoth::Both(_, _) => {
1642 if l2o.try_map(lk).is_none() {
1643 left_to_add.push(lk);
1644 }
1645 if r2o.try_map(rk).is_none() {
1646 right_to_add.push(rk + left_len)
1647 }
1648 }
1649 };
1650 }
1651 let left_to_add = left_to_add.into_iter().unique();
1652 let right_to_add = right_to_add.into_iter().unique();
1653 let mut new_output_indices = join.output_indices().clone();
1656 if !join.is_right_join() {
1657 new_output_indices.extend(left_to_add);
1658 }
1659 if !join.is_left_join() {
1660 new_output_indices.extend(right_to_add);
1661 }
1662
1663 let join_with_pk = join.clone_with_output_indices(new_output_indices);
1664
1665 let plan = if join_with_pk.join_type() == JoinType::FullOuter {
1666 let l2o = join_with_pk
1669 .core
1670 .l2i_col_mapping()
1671 .composite(&join_with_pk.core.i2o_col_mapping());
1672 let r2o = join_with_pk
1673 .core
1674 .r2i_col_mapping()
1675 .composite(&join_with_pk.core.i2o_col_mapping());
1676 let left_right_stream_keys = join_with_pk
1677 .left()
1678 .expect_stream_key()
1679 .iter()
1680 .map(|i| l2o.map(*i))
1681 .chain(
1682 join_with_pk
1683 .right()
1684 .expect_stream_key()
1685 .iter()
1686 .map(|i| r2o.map(*i)),
1687 )
1688 .collect_vec();
1689 let plan: PlanRef = join_with_pk.into();
1690 LogicalFilter::filter_out_all_null_keys(plan, &left_right_stream_keys)
1691 } else {
1692 join_with_pk.into()
1693 };
1694
1695 Ok((plan, out_col_change))
1697 }
1698}
1699
1700#[cfg(test)]
1701mod tests {
1702
1703 use std::collections::HashSet;
1704
1705 use risingwave_common::catalog::{Field, Schema};
1706 use risingwave_common::types::{DataType, Datum};
1707 use risingwave_pb::expr::expr_node::Type;
1708
1709 use super::*;
1710 use crate::expr::{FunctionCall, Literal, assert_eq_input_ref};
1711 use crate::optimizer::optimizer_context::OptimizerContext;
1712 use crate::optimizer::plan_node::LogicalValues;
1713 use crate::optimizer::property::FunctionalDependency;
1714
1715 #[tokio::test]
1729 async fn test_prune_join() {
1730 let ty = DataType::Int32;
1731 let ctx = OptimizerContext::mock().await;
1732 let fields: Vec<Field> = (1..7)
1733 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1734 .collect();
1735 let left = LogicalValues::new(
1736 vec![],
1737 Schema {
1738 fields: fields[0..3].to_vec(),
1739 },
1740 ctx.clone(),
1741 );
1742 let right = LogicalValues::new(
1743 vec![],
1744 Schema {
1745 fields: fields[3..6].to_vec(),
1746 },
1747 ctx,
1748 );
1749 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1750 FunctionCall::new(
1751 Type::Equal,
1752 vec![
1753 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1754 ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1755 ],
1756 )
1757 .unwrap(),
1758 ));
1759 let join_type = JoinType::Inner;
1760 let join: PlanRef = LogicalJoin::new(
1761 left.into(),
1762 right.into(),
1763 join_type,
1764 Condition::with_expr(on),
1765 )
1766 .into();
1767
1768 let required_cols = vec![2, 3];
1770 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1771
1772 let join = plan.as_logical_join().unwrap();
1774 assert_eq!(join.schema().fields().len(), 2);
1775 assert_eq!(join.schema().fields()[0], fields[2]);
1776 assert_eq!(join.schema().fields()[1], fields[3]);
1777
1778 let expr: ExprImpl = join.on().clone().into();
1779 let call = expr.as_function_call().unwrap();
1780 assert_eq_input_ref!(&call.inputs()[0], 0);
1781 assert_eq_input_ref!(&call.inputs()[1], 2);
1782
1783 let left = join.left();
1784 let left = left.as_logical_values().unwrap();
1785 assert_eq!(left.schema().fields(), &fields[1..3]);
1786 let right = join.right();
1787 let right = right.as_logical_values().unwrap();
1788 assert_eq!(right.schema().fields(), &fields[3..4]);
1789 }
1790
1791 #[tokio::test]
1793 async fn test_prune_semi_join() {
1794 let ty = DataType::Int32;
1795 let ctx = OptimizerContext::mock().await;
1796 let fields: Vec<Field> = (1..7)
1797 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1798 .collect();
1799 let left = LogicalValues::new(
1800 vec![],
1801 Schema {
1802 fields: fields[0..3].to_vec(),
1803 },
1804 ctx.clone(),
1805 );
1806 let right = LogicalValues::new(
1807 vec![],
1808 Schema {
1809 fields: fields[3..6].to_vec(),
1810 },
1811 ctx,
1812 );
1813 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1814 FunctionCall::new(
1815 Type::Equal,
1816 vec![
1817 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1818 ExprImpl::InputRef(Box::new(InputRef::new(4, ty))),
1819 ],
1820 )
1821 .unwrap(),
1822 ));
1823 for join_type in [
1824 JoinType::LeftSemi,
1825 JoinType::RightSemi,
1826 JoinType::LeftAnti,
1827 JoinType::RightAnti,
1828 ] {
1829 let join = LogicalJoin::new(
1830 left.clone().into(),
1831 right.clone().into(),
1832 join_type,
1833 Condition::with_expr(on.clone()),
1834 );
1835
1836 let offset = if join.is_right_join() { 3 } else { 0 };
1837 let join: PlanRef = join.into();
1838 let required_cols = vec![0];
1840 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1842 let as_plan = plan.as_logical_join().unwrap();
1843 assert_eq!(as_plan.schema().fields().len(), 1);
1845 assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1846
1847 let required_cols = vec![0, 1, 2];
1849 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1851 let as_plan = plan.as_logical_join().unwrap();
1852 assert_eq!(as_plan.schema().fields().len(), 3);
1854 assert_eq!(as_plan.schema().fields()[0], fields[offset]);
1855 assert_eq!(as_plan.schema().fields()[1], fields[offset + 1]);
1856 assert_eq!(as_plan.schema().fields()[2], fields[offset + 2]);
1857 }
1858 }
1859
1860 #[tokio::test]
1873 async fn test_prune_join_no_project() {
1874 let ty = DataType::Int32;
1875 let ctx = OptimizerContext::mock().await;
1876 let fields: Vec<Field> = (1..7)
1877 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
1878 .collect();
1879 let left = LogicalValues::new(
1880 vec![],
1881 Schema {
1882 fields: fields[0..3].to_vec(),
1883 },
1884 ctx.clone(),
1885 );
1886 let right = LogicalValues::new(
1887 vec![],
1888 Schema {
1889 fields: fields[3..6].to_vec(),
1890 },
1891 ctx,
1892 );
1893 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
1894 FunctionCall::new(
1895 Type::Equal,
1896 vec![
1897 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
1898 ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
1899 ],
1900 )
1901 .unwrap(),
1902 ));
1903 let join_type = JoinType::Inner;
1904 let join: PlanRef = LogicalJoin::new(
1905 left.into(),
1906 right.into(),
1907 join_type,
1908 Condition::with_expr(on),
1909 )
1910 .into();
1911
1912 let required_cols = vec![1, 3];
1914 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
1915
1916 let join = plan.as_logical_join().unwrap();
1918 assert_eq!(join.schema().fields().len(), 2);
1919 assert_eq!(join.schema().fields()[0], fields[1]);
1920 assert_eq!(join.schema().fields()[1], fields[3]);
1921
1922 let expr: ExprImpl = join.on().clone().into();
1923 let call = expr.as_function_call().unwrap();
1924 assert_eq_input_ref!(&call.inputs()[0], 0);
1925 assert_eq_input_ref!(&call.inputs()[1], 1);
1926
1927 let left = join.left();
1928 let left = left.as_logical_values().unwrap();
1929 assert_eq!(left.schema().fields(), &fields[1..2]);
1930 let right = join.right();
1931 let right = right.as_logical_values().unwrap();
1932 assert_eq!(right.schema().fields(), &fields[3..4]);
1933 }
1934
1935 #[tokio::test]
1949 async fn test_join_to_batch() {
1950 let ctx = OptimizerContext::mock().await;
1951 let fields: Vec<Field> = (1..7)
1952 .map(|i| Field::with_name(DataType::Int32, format!("v{}", i)))
1953 .collect();
1954 let left = LogicalValues::new(
1955 vec![],
1956 Schema {
1957 fields: fields[0..3].to_vec(),
1958 },
1959 ctx.clone(),
1960 );
1961 let right = LogicalValues::new(
1962 vec![],
1963 Schema {
1964 fields: fields[3..6].to_vec(),
1965 },
1966 ctx,
1967 );
1968
1969 fn input_ref(i: usize) -> ExprImpl {
1970 ExprImpl::InputRef(Box::new(InputRef::new(i, DataType::Int32)))
1971 }
1972 let eq_cond = ExprImpl::FunctionCall(Box::new(
1973 FunctionCall::new(Type::Equal, vec![input_ref(1), input_ref(3)]).unwrap(),
1974 ));
1975 let non_eq_cond = ExprImpl::FunctionCall(Box::new(
1976 FunctionCall::new(
1977 Type::Equal,
1978 vec![
1979 input_ref(2),
1980 ExprImpl::Literal(Box::new(Literal::new(
1981 Datum::Some(42_i32.into()),
1982 DataType::Int32,
1983 ))),
1984 ],
1985 )
1986 .unwrap(),
1987 ));
1988 let on_cond = ExprImpl::FunctionCall(Box::new(
1990 FunctionCall::new(Type::And, vec![eq_cond.clone(), non_eq_cond.clone()]).unwrap(),
1991 ));
1992
1993 let join_type = JoinType::Inner;
1994 let logical_join = LogicalJoin::new(
1995 left.into(),
1996 right.into(),
1997 join_type,
1998 Condition::with_expr(on_cond),
1999 );
2000
2001 let result = logical_join.to_batch().unwrap();
2003
2004 let hash_join = result.as_batch_hash_join().unwrap();
2006 assert_eq!(
2007 ExprImpl::from(hash_join.eq_join_predicate().eq_cond()),
2008 eq_cond
2009 );
2010 assert_eq!(
2011 *hash_join
2012 .eq_join_predicate()
2013 .non_eq_cond()
2014 .conjunctions
2015 .first()
2016 .unwrap(),
2017 non_eq_cond
2018 );
2019 }
2020
2021 #[tokio::test]
2034 #[ignore] async fn test_join_to_stream() {
2037 }
2105 #[tokio::test]
2119 async fn test_join_column_prune_with_order_required() {
2120 let ty = DataType::Int32;
2121 let ctx = OptimizerContext::mock().await;
2122 let fields: Vec<Field> = (1..7)
2123 .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
2124 .collect();
2125 let left = LogicalValues::new(
2126 vec![],
2127 Schema {
2128 fields: fields[0..3].to_vec(),
2129 },
2130 ctx.clone(),
2131 );
2132 let right = LogicalValues::new(
2133 vec![],
2134 Schema {
2135 fields: fields[3..6].to_vec(),
2136 },
2137 ctx,
2138 );
2139 let on: ExprImpl = ExprImpl::FunctionCall(Box::new(
2140 FunctionCall::new(
2141 Type::Equal,
2142 vec![
2143 ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone()))),
2144 ExprImpl::InputRef(Box::new(InputRef::new(3, ty))),
2145 ],
2146 )
2147 .unwrap(),
2148 ));
2149 let join_type = JoinType::Inner;
2150 let join: PlanRef = LogicalJoin::new(
2151 left.into(),
2152 right.into(),
2153 join_type,
2154 Condition::with_expr(on),
2155 )
2156 .into();
2157
2158 let required_cols = vec![3, 2];
2160 let plan = join.prune_col(&required_cols, &mut ColumnPruningContext::new(join.clone()));
2161
2162 let join = plan.as_logical_join().unwrap();
2164 assert_eq!(join.schema().fields().len(), 2);
2165 assert_eq!(join.schema().fields()[0], fields[3]);
2166 assert_eq!(join.schema().fields()[1], fields[2]);
2167
2168 let expr: ExprImpl = join.on().clone().into();
2169 let call = expr.as_function_call().unwrap();
2170 assert_eq_input_ref!(&call.inputs()[0], 0);
2171 assert_eq_input_ref!(&call.inputs()[1], 2);
2172
2173 let left = join.left();
2174 let left = left.as_logical_values().unwrap();
2175 assert_eq!(left.schema().fields(), &fields[1..3]);
2176 let right = join.right();
2177 let right = right.as_logical_values().unwrap();
2178 assert_eq!(right.schema().fields(), &fields[3..4]);
2179 }
2180
2181 #[tokio::test]
2182 async fn fd_derivation_inner_outer_join() {
2183 let ctx = OptimizerContext::mock().await;
2206 let left = {
2207 let fields: Vec<Field> = vec![
2208 Field::with_name(DataType::Int32, "l0"),
2209 Field::with_name(DataType::Int32, "l1"),
2210 ];
2211 let mut values = LogicalValues::new(vec![], Schema { fields }, ctx.clone());
2212 values
2214 .base
2215 .functional_dependency_mut()
2216 .add_functional_dependency_by_column_indices(&[0], &[1]);
2217 values
2218 };
2219 let right = {
2220 let fields: Vec<Field> = vec![
2221 Field::with_name(DataType::Int32, "r0"),
2222 Field::with_name(DataType::Int32, "r1"),
2223 Field::with_name(DataType::Int32, "r2"),
2224 ];
2225 let mut values = LogicalValues::new(vec![], Schema { fields }, ctx);
2226 values
2228 .base
2229 .functional_dependency_mut()
2230 .add_functional_dependency_by_column_indices(&[0], &[1, 2]);
2231 values
2232 };
2233 let on: ExprImpl = FunctionCall::new(
2235 Type::And,
2236 vec![
2237 FunctionCall::new(
2238 Type::Equal,
2239 vec![
2240 InputRef::new(0, DataType::Int32).into(),
2241 ExprImpl::literal_int(0),
2242 ],
2243 )
2244 .unwrap()
2245 .into(),
2246 FunctionCall::new(
2247 Type::Equal,
2248 vec![
2249 InputRef::new(1, DataType::Int32).into(),
2250 InputRef::new(3, DataType::Int32).into(),
2251 ],
2252 )
2253 .unwrap()
2254 .into(),
2255 ],
2256 )
2257 .unwrap()
2258 .into();
2259 let expected_fd_set = [
2260 (
2261 JoinType::Inner,
2262 [
2263 FunctionalDependency::with_indices(5, &[0], &[1]),
2265 FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2267 FunctionalDependency::with_indices(5, &[], &[0]),
2269 FunctionalDependency::with_indices(5, &[1], &[3]),
2271 FunctionalDependency::with_indices(5, &[3], &[1]),
2272 ]
2273 .into_iter()
2274 .collect::<HashSet<_>>(),
2275 ),
2276 (JoinType::FullOuter, HashSet::new()),
2277 (
2278 JoinType::RightOuter,
2279 [
2280 FunctionalDependency::with_indices(5, &[2], &[3, 4]),
2282 ]
2283 .into_iter()
2284 .collect::<HashSet<_>>(),
2285 ),
2286 (
2287 JoinType::LeftOuter,
2288 [
2289 FunctionalDependency::with_indices(5, &[0], &[1]),
2291 ]
2292 .into_iter()
2293 .collect::<HashSet<_>>(),
2294 ),
2295 (
2296 JoinType::LeftSemi,
2297 [
2298 FunctionalDependency::with_indices(2, &[0], &[1]),
2300 ]
2301 .into_iter()
2302 .collect::<HashSet<_>>(),
2303 ),
2304 (
2305 JoinType::LeftAnti,
2306 [
2307 FunctionalDependency::with_indices(2, &[0], &[1]),
2309 ]
2310 .into_iter()
2311 .collect::<HashSet<_>>(),
2312 ),
2313 (
2314 JoinType::RightSemi,
2315 [
2316 FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2318 ]
2319 .into_iter()
2320 .collect::<HashSet<_>>(),
2321 ),
2322 (
2323 JoinType::RightAnti,
2324 [
2325 FunctionalDependency::with_indices(3, &[0], &[1, 2]),
2327 ]
2328 .into_iter()
2329 .collect::<HashSet<_>>(),
2330 ),
2331 ];
2332
2333 for (join_type, expected_res) in expected_fd_set {
2334 let join = LogicalJoin::new(
2335 left.clone().into(),
2336 right.clone().into(),
2337 join_type,
2338 Condition::with_expr(on.clone()),
2339 );
2340 let fd_set = join
2341 .functional_dependency()
2342 .as_dependencies()
2343 .iter()
2344 .cloned()
2345 .collect::<HashSet<_>>();
2346 assert_eq!(fd_set, expected_res);
2347 }
2348 }
2349}