1use std::cmp::Ordering;
16use std::ops::{Bound, Deref, DerefMut, RangeBounds};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use futures::future::{join, try_join};
21use futures::{StreamExt, pin_mut, stream};
22use futures_async_stream::for_await;
23use join_row_set::JoinRowSet;
24use risingwave_common::bitmap::Bitmap;
25use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
26use risingwave_common::metrics::LabelGuardedIntCounter;
27use risingwave_common::row::{OwnedRow, Row, RowExt};
28use risingwave_common::types::{DataType, ScalarImpl};
29use risingwave_common::util::epoch::EpochPair;
30use risingwave_common::util::iter_util::ZipEqFast;
31use risingwave_common::util::row_serde::OrderedRowSerde;
32use risingwave_common::util::sort_util::OrderType;
33use risingwave_common_estimate_size::EstimateSize;
34use risingwave_storage::StateStore;
35use risingwave_storage::store::PrefetchOptions;
36use risingwave_storage::table::KeyedRow;
37use thiserror_ext::AsReport;
38
39use super::row::{CachedJoinRow, DegreeType, build_degree_row};
40use crate::cache::ManagedLruCache;
41use crate::common::metrics::MetricsInfo;
42use crate::common::table::state_table::{StateTable, StateTablePostCommit};
43use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
44use crate::executor::error::StreamExecutorResult;
45use crate::executor::join::row::JoinRow;
46use crate::executor::monitor::StreamingMetrics;
47use crate::executor::{JoinEncoding, StreamExecutorError};
48use crate::task::{ActorId, AtomicU64Ref, FragmentId};
49
/// Memcmp-serialized primary key, used as the key within a cached join entry.
type PkType = Vec<u8>;
/// Memcmp-serialized inequality-key column, used by the inequality index.
type InequalKeyType = Vec<u8>;

/// The cached value for one join key: the boxed set of rows matching that key.
pub type HashValueType<E> = Box<JoinEntryState<E>>;
55
56impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
57 fn estimated_heap_size(&self) -> usize {
58 self.as_ref().estimated_heap_size()
59 }
60}
61
/// Newtype over an optional cached entry; `None` only transiently while the
/// entry has been `take`n out and not yet restored.
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
68
69pub(crate) enum CacheResult<E: JoinEncoding> {
70 NeverMatch, Miss, Hit(HashValueType<E>), }
74
75impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
76 fn estimated_heap_size(&self) -> usize {
77 self.0.estimated_heap_size()
78 }
79}
80
81impl<E: JoinEncoding> HashValueWrapper<E> {
82 const MESSAGE: &'static str = "the state should always be `Some`";
83
84 pub fn take(&mut self) -> HashValueType<E> {
86 self.0.take().expect(Self::MESSAGE)
87 }
88}
89
90impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
91 type Target = HashValueType<E>;
92
93 fn deref(&self) -> &Self::Target {
94 self.0.as_ref().expect(Self::MESSAGE)
95 }
96}
97
98impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
99 fn deref_mut(&mut self) -> &mut Self::Target {
100 self.0.as_mut().expect(Self::MESSAGE)
101 }
102}
103
/// LRU cache from join key to (optional) cached entry, hashed with the key's
/// precomputed hash.
type JoinHashMapInner<K, E> = ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher>;
105
/// Locally buffered metrics for a join hash map: counters are bumped as plain
/// fields on the hot path and pushed to the guarded counters in `flush`.
pub struct JoinHashMapMetrics {
    // Cache lookup misses since the last flush.
    lookup_miss_count: usize,
    // Total cache lookups since the last flush.
    total_lookup_count: usize,
    // Insert-path cache misses since the last flush.
    insert_cache_miss_count: usize,

    // Label-guarded counters the buffered values are flushed into.
    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}
119
120impl JoinHashMapMetrics {
121 pub fn new(
122 metrics: &StreamingMetrics,
123 actor_id: ActorId,
124 fragment_id: FragmentId,
125 side: &'static str,
126 join_table_id: TableId,
127 ) -> Self {
128 let actor_id = actor_id.to_string();
129 let fragment_id = fragment_id.to_string();
130 let join_table_id = join_table_id.to_string();
131 let join_lookup_total_count_metric = metrics
132 .join_lookup_total_count
133 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
134 let join_lookup_miss_count_metric = metrics
135 .join_lookup_miss_count
136 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
137 let join_insert_cache_miss_count_metrics = metrics
138 .join_insert_cache_miss_count
139 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
140
141 Self {
142 lookup_miss_count: 0,
143 total_lookup_count: 0,
144 insert_cache_miss_count: 0,
145 join_lookup_total_count_metric,
146 join_lookup_miss_count_metric,
147 join_insert_cache_miss_count_metrics,
148 }
149 }
150
151 pub fn flush(&mut self) {
152 self.join_lookup_total_count_metric
153 .inc_by(self.total_lookup_count as u64);
154 self.join_lookup_miss_count_metric
155 .inc_by(self.lookup_miss_count as u64);
156 self.join_insert_cache_miss_count_metrics
157 .inc_by(self.insert_cache_miss_count as u64);
158 self.total_lookup_count = 0;
159 self.lookup_miss_count = 0;
160 self.insert_cache_miss_count = 0;
161 }
162}
163
/// How to extract and serialize the inequality-key column of a row.
struct InequalityKeyDesc {
    // Column index of the inequality key within a full row.
    idx: usize,
    // Order-preserving (memcmp) serializer for that single column.
    serializer: OrderedRowSerde,
}
169
170impl InequalityKeyDesc {
171 pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
173 let indices = vec![self.idx];
174 let inequality_key = row.project(&indices);
175 inequality_key.memcmp_serialize(&self.serializer)
176 }
177}
178
/// In-memory LRU cache plus persistent state for one side of a streaming hash join.
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    // LRU cache: join key -> cached join rows for that key.
    inner: JoinHashMapInner<K, E>,
    // Data types of the join key columns.
    join_key_data_types: Vec<DataType>,
    // Bitmap over join-key columns; presumably marks columns where NULLs are
    // treated as matching — TODO confirm against `null_matched()` callers.
    null_matched: K::Bitmap,
    // Serializer for the upstream primary-key columns (memcmp order).
    pk_serializer: OrderedRowSerde,
    // State table holding the join rows.
    state: TableInner<S>,
    // Optional degree table tracking per-row match counts.
    degree_state: Option<TableInner<S>>,
    // Whether this side maintains a degree table (degree_state.is_some() at construction).
    need_degree_table: bool,
    // Whether the primary key columns are fully contained in the join key.
    pk_contained_in_jk: bool,
    // Metadata for serializing the optional inequality key column.
    inequality_key_desc: Option<InequalityKeyDesc>,
    // Buffered metrics, flushed on commit.
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}
232
233impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
234 pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
235 (&self.state.order_key_indices, &mut self.degree_state)
236 }
237
238 pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
245 &'a mut self,
246 key: &'a K,
247 ) -> StreamExecutorResult<(
248 impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
249 &'a [usize],
250 &'a mut Option<TableInner<S>>,
251 )> {
252 let degree_state = &mut self.degree_state;
253 let (order_key_indices, pk_indices, state_table) = (
254 &self.state.order_key_indices,
255 &self.state.pk_indices,
256 &mut self.state.table,
257 );
258 let degrees = if let Some(degree_state) = degree_state {
259 Some(fetch_degrees(key, &self.join_key_data_types, °ree_state.table).await?)
260 } else {
261 None
262 };
263 let stream = into_stream(
264 &self.join_key_data_types,
265 pk_indices,
266 &self.pk_serializer,
267 state_table,
268 key,
269 degrees,
270 );
271 Ok((stream, order_key_indices, &mut self.degree_state))
272 }
273}
274
/// Stream all rows stored under `key`, each paired with its memcmp-serialized
/// pk and its degree.
///
/// Degrees are matched to rows by scan position (`d[i]`), so `degrees` must
/// come from a scan in the same order — see `fetch_degrees`. When `degrees`
/// is `None`, every row gets degree 0.
#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    // Scan the whole prefix under `key` (no sub-range restriction).
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix_respecting_watermark(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        // Serialize the upstream pk; it becomes the key inside the cache entry.
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
        yield (encoded_pk, join_row);
    }
}
301
/// Read all degree values stored under `key`, in scan order.
///
/// The returned vector is later paired positionally with the row-table scan
/// for the same key (see `into_stream`), so both scans must share the same
/// ordering.
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix_respecting_watermark(key, sub_range, PrefetchOptions::default())
        .await?;
    // Physical position of the degree column inside a degree row.
    let degree_col_idx = degree_col_idx_in_row(degree_state_table);
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        debug_assert!(
            degree_row.len() > degree_col_idx,
            "degree row should have at least pk_len + 1 columns"
        );
        let degree_i64 = degree_row
            .datum_at(degree_col_idx)
            .expect("degree should not be NULL");
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}
351
352fn degree_col_idx_in_row<S: StateStore>(degree_state_table: &StateTable<S>) -> usize {
353 let degree_col_idx = degree_state_table.pk_indices().len();
355 match degree_state_table.value_indices() {
356 Some(value_indices) => value_indices
357 .iter()
358 .position(|idx| *idx == degree_col_idx)
359 .expect("degree column should be included in value indices"),
360 None => degree_col_idx,
361 }
362}
363
364pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
368 order_key_indices: &[usize],
369 degree_state: &mut TableInner<S>,
370 matched_row: &mut JoinRow<impl Row>,
371) {
372 let inequality_idx = degree_state.degree_inequality_idx;
373 let old_degree_row = build_degree_row(
374 order_key_indices,
375 matched_row.degree,
376 inequality_idx,
377 &matched_row.row,
378 );
379 if INCREMENT {
380 matched_row.degree += 1;
381 } else {
382 matched_row.degree -= 1;
384 }
385 let new_degree_row = build_degree_row(
386 order_key_indices,
387 matched_row.degree,
388 inequality_idx,
389 &matched_row.row,
390 );
391 degree_state.table.update(old_degree_row, new_degree_row);
392}
393
/// A state table plus the index metadata the join needs to address it.
pub struct TableInner<S: StateStore> {
    // Indices of the upstream pk columns within a full row.
    pk_indices: Vec<usize>,
    // Indices of the join key columns within a full row.
    join_key_indices: Vec<usize>,
    // The state table's own pk (order key) column indices.
    order_key_indices: Vec<usize>,
    // Index of the inequality column, used when building degree rows.
    pub(crate) degree_inequality_idx: Option<usize>,
    pub(crate) table: StateTable<S>,
}
410
411impl<S: StateStore> TableInner<S> {
412 pub fn new(
413 pk_indices: Vec<usize>,
414 join_key_indices: Vec<usize>,
415 table: StateTable<S>,
416 degree_inequality_idx: Option<usize>,
417 ) -> Self {
418 let order_key_indices = table.pk_indices().to_vec();
419 Self {
420 pk_indices,
421 join_key_indices,
422 order_key_indices,
423 degree_inequality_idx,
424 table,
425 }
426 }
427
428 fn error_context(&self, row: &impl Row) -> String {
429 let pk = row.project(&self.pk_indices);
430 let jk = row.project(&self.join_key_indices);
431 format!(
432 "join key: {}, pk: {}, row: {}, state_table_id: {}",
433 jk.display(),
434 pk.display(),
435 row.display(),
436 self.table.table_id()
437 )
438 }
439}
440
441impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] for one side of the join.
    ///
    /// `degree_state` is `Some` iff this side maintains a degree table;
    /// `inequality_key_idx` enables the inequality index maintained inside
    /// cached entries.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        // Serializer over the upstream pk columns, ascending memcmp order.
        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        // Single-column serializer for the optional inequality key.
        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            degree_inequality_idx: inequality_key_idx,
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        // Unbounded LRU: eviction is driven by the shared watermark sequence.
        let cache = ManagedLruCache::unbounded_with_hasher(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }
516
517 pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
518 self.state.table.init_epoch(epoch).await?;
519 if let Some(degree_state) = &mut self.degree_state {
520 degree_state.table.init_epoch(epoch).await?;
521 }
522 Ok(())
523 }
524}
525
526impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
527 pub async fn post_yield_barrier(
528 self,
529 vnode_bitmap: Option<Arc<Bitmap>>,
530 ) -> StreamExecutorResult<Option<bool>> {
531 let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
532 if let Some(degree_state) = self.degree_state {
533 let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
534 }
535 let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
536 if cache_may_stale.unwrap_or(false) {
537 self.inner.clear();
538 }
539 Ok(cache_may_stale)
540 }
541}
542impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
543 pub fn update_watermark(&mut self, watermark: ScalarImpl) {
544 self.state.table.update_watermark(watermark.clone());
546 if let Some(degree_state) = &mut self.degree_state {
547 degree_state.table.update_watermark(watermark);
548 }
549 }
550
551 pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
560 self.metrics.total_lookup_count += 1;
561 if self.inner.contains(key) {
562 tracing::trace!("hit cache for join key: {:?}", key);
563 let mut state = self.inner.peek_mut(key).expect("checked contains");
566 CacheResult::Hit(state.take())
567 } else {
568 self.metrics.lookup_miss_count += 1;
569 tracing::trace!("miss cache for join key: {:?}", key);
570 CacheResult::Miss
571 }
572 }
573
574 pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
583 self.metrics.total_lookup_count += 1;
584 let state = if self.inner.contains(key) {
585 let mut state = self.inner.peek_mut(key).unwrap();
588 state.take()
589 } else {
590 self.metrics.lookup_miss_count += 1;
591 self.fetch_cached_state(key).await?.into()
592 };
593 Ok(state)
594 }
595
596 async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
599 let key = key.deserialize(&self.join_key_data_types)?;
600
601 let mut entry_state: JoinEntryState<E> = JoinEntryState::default();
602
603 if self.need_degree_table {
604 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
605 &(Bound::Unbounded, Bound::Unbounded);
606 let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
607 &key,
608 sub_range,
609 PrefetchOptions::default(),
610 );
611 let degree_state = self.degree_state.as_ref().unwrap();
612 let degree_col_idx = degree_col_idx_in_row(°ree_state.table);
613 let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
614 &key,
615 sub_range,
616 PrefetchOptions::default(),
617 );
618
619 let (table_iter, degree_table_iter) =
620 try_join(table_iter_fut, degree_table_iter_fut).await?;
621
622 let mut pinned_table_iter = std::pin::pin!(table_iter);
623 let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);
624
625 let mut rows = vec![];
628 let mut degree_rows = vec![];
629 let mut inconsistency_happened = false;
630 loop {
631 let (row, degree_row) =
632 join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
633 let (row, degree_row) = match (row, degree_row) {
634 (None, None) => break,
635 (None, Some(_)) => {
636 inconsistency_happened = true;
637 consistency_panic!(
638 "mismatched row and degree table of join key: {:?}, degree table has more rows",
639 &key
640 );
641 break;
642 }
643 (Some(_), None) => {
644 inconsistency_happened = true;
645 consistency_panic!(
646 "mismatched row and degree table of join key: {:?}, input table has more rows",
647 &key
648 );
649 break;
650 }
651 (Some(r), Some(d)) => (r, d),
652 };
653
654 let row = row?;
655 let degree_row = degree_row?;
656 rows.push(row);
657 degree_rows.push(degree_row);
658 }
659
660 if inconsistency_happened {
661 assert_ne!(rows.len(), degree_rows.len());
663
664 let row_iter = stream::iter(rows.into_iter()).peekable();
665 let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
666 pin_mut!(row_iter);
667 pin_mut!(degree_row_iter);
668
669 loop {
670 match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
671 (None, _) | (_, None) => break,
672 (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
673 Ordering::Greater => {
674 degree_row_iter.next().await;
675 }
676 Ordering::Less => {
677 row_iter.next().await;
678 }
679 Ordering::Equal => {
680 let row =
681 row_iter.next().await.expect("we matched some(row) above");
682 let degree_row = degree_row_iter
683 .next()
684 .await
685 .expect("we matched some(degree_row) above");
686 let pk = row
687 .as_ref()
688 .project(&self.state.pk_indices)
689 .memcmp_serialize(&self.pk_serializer);
690 let degree_i64 = degree_row
691 .datum_at(degree_col_idx)
692 .expect("degree should not be NULL");
693 let inequality_key = self
694 .inequality_key_desc
695 .as_ref()
696 .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
697 entry_state
698 .insert(
699 pk,
700 E::encode(&JoinRow::new(
701 row.row(),
702 degree_i64.into_int64() as u64,
703 )),
704 inequality_key,
705 )
706 .with_context(|| self.state.error_context(row.row()))?;
707 }
708 },
709 }
710 }
711 } else {
712 assert_eq!(rows.len(), degree_rows.len());
717
718 #[for_await]
719 for (row, degree_row) in
720 stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
721 {
722 let row: KeyedRow<_> = row;
723 let degree_row: KeyedRow<_> = degree_row;
724
725 let pk1 = row.key();
726 let pk2 = degree_row.key();
727 debug_assert_eq!(
728 pk1, pk2,
729 "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
730 );
731 let pk = row
732 .as_ref()
733 .project(&self.state.pk_indices)
734 .memcmp_serialize(&self.pk_serializer);
735 let inequality_key = self
736 .inequality_key_desc
737 .as_ref()
738 .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
739 let degree_i64 = degree_row
740 .datum_at(degree_col_idx)
741 .expect("degree should not be NULL");
742 entry_state
743 .insert(
744 pk,
745 E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
746 inequality_key,
747 )
748 .with_context(|| self.state.error_context(row.row()))?;
749 }
750 }
751 } else {
752 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
753 &(Bound::Unbounded, Bound::Unbounded);
754 let table_iter = self
755 .state
756 .table
757 .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
758 .await?;
759
760 #[for_await]
761 for entry in table_iter {
762 let row: KeyedRow<_> = entry?;
763 let pk = row
764 .as_ref()
765 .project(&self.state.pk_indices)
766 .memcmp_serialize(&self.pk_serializer);
767 let inequality_key = self
768 .inequality_key_desc
769 .as_ref()
770 .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
771 entry_state
772 .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
773 .with_context(|| self.state.error_context(row.row()))?;
774 }
775 };
776
777 Ok(entry_state)
778 }
779
780 pub async fn flush(
781 &mut self,
782 epoch: EpochPair,
783 ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
784 self.metrics.flush();
785 let state_post_commit = self.state.table.commit(epoch).await?;
786 let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
787 Some(degree_state.table.commit(epoch).await?)
788 } else {
789 None
790 };
791 Ok(JoinHashMapPostCommit {
792 state: state_post_commit,
793 degree_state: degree_state_post_commit,
794 inner: &mut self.inner,
795 })
796 }
797
798 pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
799 self.state.table.try_flush().await?;
800 if let Some(degree_state) = &mut self.degree_state {
801 degree_state.table.try_flush().await?;
802 }
803 Ok(())
804 }
805
806 pub fn insert_handle_degree(
807 &mut self,
808 key: &K,
809 value: JoinRow<impl Row>,
810 ) -> StreamExecutorResult<()> {
811 if self.need_degree_table {
812 self.insert(key, value)
813 } else {
814 self.insert_row(key, value.row)
815 }
816 }
817
    /// Insert a join row into the cache (when possible) and into the state
    /// table(s), maintaining the degree table if one exists.
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        // Only touch the cache when the entry already exists, or when a fresh
        // entry can be created without a storage fetch — presumably valid
        // because pk ⊆ jk implies a missed key has no other rows; TODO confirm.
        if self.inner.contains(key) {
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        // Persist the row; with a degree table, also persist its degree row.
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(
                &self.state.order_key_indices,
                degree_state.degree_inequality_idx,
            );
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }
858
859 pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
862 let join_row = JoinRow::new(&value, 0);
863 self.insert(key, join_row)?;
864 Ok(())
865 }
866
867 pub fn delete_row_in_mem(&mut self, key: &K, value: &impl Row) -> StreamExecutorResult<()> {
868 if let Some(mut entry) = self.inner.get_mut(key) {
869 let pk = (&value)
870 .project(&self.state.pk_indices)
871 .memcmp_serialize(&self.pk_serializer);
872
873 let inequality_key = self
874 .inequality_key_desc
875 .as_ref()
876 .map(|desc| desc.serialize_inequal_key_from_row(value));
877 entry
878 .remove(pk, inequality_key.as_ref())
879 .with_context(|| self.state.error_context(&value))?;
880 }
881 Ok(())
882 }
883
884 pub fn delete_handle_degree(
885 &mut self,
886 key: &K,
887 value: JoinRow<impl Row>,
888 ) -> StreamExecutorResult<()> {
889 if self.need_degree_table {
890 self.delete(key, value)
891 } else {
892 self.delete_row(key, value.row)
893 }
894 }
895
896 pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
898 self.delete_row_in_mem(key, &value.row)?;
899
900 let degree_state = self.degree_state.as_mut().expect("degree table missing");
902 let (row, degree) = value.to_table_rows(
903 &self.state.order_key_indices,
904 degree_state.degree_inequality_idx,
905 );
906 self.state.table.delete(row);
907 degree_state.table.delete(degree);
908 Ok(())
909 }
910
911 pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
914 self.delete_row_in_mem(key, &value)?;
915
916 self.state.table.delete(value);
918 Ok(())
919 }
920
    /// Put `state` into the cache for `key`, replacing any existing entry.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict cache entries (policy driven by the managed LRU's watermark).
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Number of cached join keys (not rows).
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    // Bitmap over join-key columns; see the `null_matched` field.
    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    /// Id of the underlying row state table.
    pub fn table_id(&self) -> TableId {
        self.state.table.table_id()
    }

    /// Data types of the join key columns.
    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Whether the inequality-key column of `row` is NULL.
    /// Panics if no inequality key is configured for this join.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self
            .inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing");
        row.datum_at(desc.idx).is_none()
    }

    /// Memcmp-serialize the inequality key of `row`.
    /// Panics if no inequality key is configured for this join.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing")
            .serialize_inequal_key_from_row(&row)
    }

    /// Memcmp-serialize the primary-key columns of `row`.
    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
973}
974
/// Pending post-commit work for a [`JoinHashMap`]; must be consumed via
/// `post_yield_barrier` after the barrier has been emitted.
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    // Cache handle so a stale cache can be cleared after the barrier.
    inner: &'a mut JoinHashMapInner<K, E>,
}
981
982use risingwave_common::catalog::TableId;
983use risingwave_common_estimate_size::KvSize;
984use thiserror::Error;
985
986use super::*;
987use crate::executor::prelude::{Stream, try_stream};
988
/// The in-memory cache entry for one join key: all matching rows plus an
/// optional index from inequality key to the pks carrying that key.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    // Cached rows, keyed by memcmp-serialized pk.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    // inequality key -> set of pks of rows carrying that key.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    // Tracked heap footprint of the maps above.
    kv_heap_size: KvSize,
}
1002
impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        // `kv_heap_size` is maintained incrementally by insert/remove, so the
        // current total can be reported directly.
        self.kv_heap_size.size()
    }
}
1010
/// Errors surfaced by [`JoinEntryState`] under strict consistency mode.
#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}
1020
1021impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert a cached row under `key`, maintaining heap-size accounting and
    /// the optional inequality index.
    ///
    /// In non-strict consistency mode a pre-existing entry is removed first
    /// (and the double insert reported via `consistency_error!`); in strict
    /// mode a duplicate insert returns [`JoinEntryError::Occupied`].
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // Best-effort mode: drop any stale entry first so the try_insert
            // below cannot fail.
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // Report the double insert without failing the operation.
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }
1058
1059 pub fn remove(
1061 &mut self,
1062 pk: PkType,
1063 inequality_key: Option<&InequalKeyType>,
1064 ) -> Result<(), JoinEntryError> {
1065 if let Some(value) = self.cached.remove(&pk) {
1066 self.kv_heap_size.sub(&pk, &value);
1067 if let Some(inequality_key) = inequality_key {
1068 self.remove_pk_from_inequality_index(&pk, inequality_key);
1069 }
1070 Ok(())
1071 } else if enable_strict_consistency() {
1072 Err(JoinEntryError::Remove)
1073 } else {
1074 consistency_error!(?pk, "removing a join state entry but it's not in the cache");
1075 Ok(())
1076 }
1077 }
1078
1079 fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
1080 if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
1081 if pk_set.remove(pk).is_none() {
1082 if enable_strict_consistency() {
1083 panic!("removing a pk that it not in the inequality index");
1084 } else {
1085 consistency_error!(?pk, "removing a pk that it not in the inequality index");
1086 };
1087 } else {
1088 self.kv_heap_size.sub(pk, &());
1089 }
1090 if pk_set.is_empty() {
1091 self.inequality_index.remove(inequality_key);
1092 }
1093 }
1094 }
1095
1096 fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
1097 if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
1098 let pk_size = pk.estimated_size();
1099 if pk_set.try_insert(pk, ()).is_err() {
1100 if enable_strict_consistency() {
1101 panic!("inserting a pk that it already in the inequality index");
1102 } else {
1103 consistency_error!("inserting a pk that it already in the inequality index");
1104 };
1105 } else {
1106 self.kv_heap_size.add_size(pk_size);
1107 }
1108 } else {
1109 let mut pk_set = JoinRowSet::default();
1110 pk_set.try_insert(pk, ()).expect("pk set should be empty");
1111 self.inequality_index
1112 .try_insert(inequality_key, pk_set)
1113 .expect("pk set should be empty");
1114 }
1115 }
1116
1117 pub fn get(
1118 &self,
1119 pk: &PkType,
1120 data_types: &[DataType],
1121 ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1122 self.cached
1123 .get(pk)
1124 .map(|encoded| encoded.decode(data_types))
1125 }
1126
    /// Iterate over all cached rows, yielding for each a mutable reference to
    /// its encoded form together with its (eagerly) decoded form.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<E::DecodedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            // Decode before handing out the mutable encoded reference.
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }
1146
1147 pub fn len(&self) -> usize {
1148 self.cached.len()
1149 }
1150
    /// Iterate over all cached rows whose inequality key falls within `range`,
    /// in inequality-key order (pk order within one inequality key).
    pub fn range_by_inequality<'a, R>(
        &'a self,
        range: R,
        data_types: &'a [DataType],
    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<E::DecodedRow>>> + 'a
    where
        R: RangeBounds<InequalKeyType> + 'a,
    {
        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
            pk_set
                .keys()
                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
        })
    }
1166
1167 pub fn upper_bound_by_inequality<'a>(
1169 &'a self,
1170 bound: Bound<&InequalKeyType>,
1171 data_types: &'a [DataType],
1172 ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1173 if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
1174 if let Some(pk) = pk_set.first_key_sorted() {
1175 self.get_by_indexed_pk(pk, data_types)
1176 } else {
1177 panic!("pk set for a index record must has at least one element");
1178 }
1179 } else {
1180 None
1181 }
1182 }
1183
1184 pub fn get_by_indexed_pk(
1185 &self,
1186 pk: &PkType,
1187 data_types: &[DataType],
1188 ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>>
1189where {
1190 if let Some(value) = self.cached.get(pk) {
1191 Some(value.decode(data_types))
1192 } else if enable_strict_consistency() {
1193 Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
1194 } else {
1195 consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
1196 None
1197 }
1198 }
1199
1200 pub fn lower_bound_by_inequality<'a>(
1202 &'a self,
1203 bound: Bound<&InequalKeyType>,
1204 data_types: &'a [DataType],
1205 ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1206 if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
1207 if let Some(pk) = pk_set.first_key_sorted() {
1208 self.get_by_indexed_pk(pk, data_types)
1209 } else {
1210 panic!("pk set for a index record must has at least one element");
1211 }
1212 } else {
1213 None
1214 }
1215 }
1216
1217 pub fn get_first_by_inequality<'a>(
1218 &'a self,
1219 inequality_key: &InequalKeyType,
1220 data_types: &'a [DataType],
1221 ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1222 if let Some(pk_set) = self.inequality_index.get(inequality_key) {
1223 if let Some(pk) = pk_set.first_key_sorted() {
1224 self.get_by_indexed_pk(pk, data_types)
1225 } else {
1226 panic!("pk set for a index record must has at least one element");
1227 }
1228 } else {
1229 None
1230 }
1231 }
1232
    /// Read-only view of the inequality-key -> pk-set index.
    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
1236}
1237
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    /// Encode every row of `data_chunk` and insert it into `managed_state`.
    ///
    /// The pk is built from `pk_indices` and memcmp-serialized in ascending
    /// order; if `inequality_key_idx` is `Some(idx)`, column `idx` is
    /// serialized and attached as the row's inequality-index key.
    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        // The inequality key is always a single column, serialized ascending.
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            // NOTE(review): `value_indices` is computed from the full row
            // width (`row.len() - 1`) but used below to project the pk row,
            // whose width is `pk_indices.len()`. The two only coincide for
            // the 2-column, single-column-pk chunks used in these tests —
            // confirm this is intentional before reusing with other shapes.
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            // Degree starts at 0 for freshly inserted rows.
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    /// Assert that iterating `managed_state` via `values_mut` yields exactly
    /// the rows `(col1[i], col2[i])`, in order, each decoding successfully
    /// and with degree 0. `zip_eq_debug` also enforces matching lengths.
    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    // Basic insert-then-iterate round trip, without an inequality index.
    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        // Expected contents after the first insertion; the order must match
        // whatever iteration order `values_mut` produces.
        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // Insert two more rows (pks 4 and 5) and re-check the full contents.
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    // Exercises the inequality index: upper/lower bound lookups and the
    // smallest-pk tie-break among rows sharing an inequality key.
    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        // Column 1 serves as the inequality-index key.
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        // Note: pks 1 and 2 share the inequality key 5.
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        // Included(5): key 5 itself qualifies; among pks {1, 2} sharing it,
        // the smallest pk (1) is returned (`first_key_sorted`).
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        // Excluded(5): the largest key strictly below 5 is 4, held by pk 3.
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        // Insert two more rows (pk 5 -> key 8, pk 4 -> key 4) and re-check.
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // 8 is the maximum key present, so a lower bound strictly above it
        // matches nothing.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}