1use std::alloc::Global;
15use std::cmp::Ordering;
16use std::ops::{Bound, Deref, DerefMut, RangeBounds};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use futures::future::{join, try_join};
21use futures::{StreamExt, pin_mut, stream};
22use futures_async_stream::for_await;
23use join_row_set::JoinRowSet;
24use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
25use risingwave_common::bitmap::Bitmap;
26use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
27use risingwave_common::metrics::LabelGuardedIntCounter;
28use risingwave_common::row::{OwnedRow, Row, RowExt, once};
29use risingwave_common::types::{DataType, ScalarImpl};
30use risingwave_common::util::epoch::EpochPair;
31use risingwave_common::util::iter_util::ZipEqFast;
32use risingwave_common::util::row_serde::OrderedRowSerde;
33use risingwave_common::util::sort_util::OrderType;
34use risingwave_common_estimate_size::EstimateSize;
35use risingwave_storage::StateStore;
36use risingwave_storage::store::PrefetchOptions;
37use thiserror_ext::AsReport;
38
39use super::row::{DegreeType, EncodedJoinRow};
40use crate::cache::ManagedLruCache;
41use crate::common::metrics::MetricsInfo;
42use crate::common::table::state_table::{StateTable, StateTablePostCommit};
43use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
44use crate::executor::StreamExecutorError;
45use crate::executor::error::StreamExecutorResult;
46use crate::executor::join::row::JoinRow;
47use crate::executor::monitor::StreamingMetrics;
48use crate::task::{ActorId, AtomicU64Ref, FragmentId};
49
50type PkType = Vec<u8>;
52type InequalKeyType = Vec<u8>;
53
54pub type StateValueType = EncodedJoinRow;
55pub type HashValueType = Box<JoinEntryState>;
56
57impl EstimateSize for HashValueType {
58 fn estimated_heap_size(&self) -> usize {
59 self.as_ref().estimated_heap_size()
60 }
61}
62
63struct HashValueWrapper(Option<HashValueType>);
69
70pub(crate) enum CacheResult {
71 NeverMatch, Miss, Hit(HashValueType), }
75
76impl EstimateSize for HashValueWrapper {
77 fn estimated_heap_size(&self) -> usize {
78 self.0.estimated_heap_size()
79 }
80}
81
82impl HashValueWrapper {
83 const MESSAGE: &'static str = "the state should always be `Some`";
84
85 pub fn take(&mut self) -> HashValueType {
87 self.0.take().expect(Self::MESSAGE)
88 }
89}
90
91impl Deref for HashValueWrapper {
92 type Target = HashValueType;
93
94 fn deref(&self) -> &Self::Target {
95 self.0.as_ref().expect(Self::MESSAGE)
96 }
97}
98
99impl DerefMut for HashValueWrapper {
100 fn deref_mut(&mut self) -> &mut Self::Target {
101 self.0.as_mut().expect(Self::MESSAGE)
102 }
103}
104
105type JoinHashMapInner<K> =
106 ManagedLruCache<K, HashValueWrapper, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;
107
108pub struct JoinHashMapMetrics {
109 lookup_miss_count: usize,
112 total_lookup_count: usize,
113 insert_cache_miss_count: usize,
115
116 join_lookup_total_count_metric: LabelGuardedIntCounter<4>,
118 join_lookup_miss_count_metric: LabelGuardedIntCounter<4>,
119 join_insert_cache_miss_count_metrics: LabelGuardedIntCounter<4>,
120}
121
122impl JoinHashMapMetrics {
123 pub fn new(
124 metrics: &StreamingMetrics,
125 actor_id: ActorId,
126 fragment_id: FragmentId,
127 side: &'static str,
128 join_table_id: u32,
129 ) -> Self {
130 let actor_id = actor_id.to_string();
131 let fragment_id = fragment_id.to_string();
132 let join_table_id = join_table_id.to_string();
133 let join_lookup_total_count_metric = metrics
134 .join_lookup_total_count
135 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
136 let join_lookup_miss_count_metric = metrics
137 .join_lookup_miss_count
138 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
139 let join_insert_cache_miss_count_metrics = metrics
140 .join_insert_cache_miss_count
141 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
142
143 Self {
144 lookup_miss_count: 0,
145 total_lookup_count: 0,
146 insert_cache_miss_count: 0,
147 join_lookup_total_count_metric,
148 join_lookup_miss_count_metric,
149 join_insert_cache_miss_count_metrics,
150 }
151 }
152
153 pub fn flush(&mut self) {
154 self.join_lookup_total_count_metric
155 .inc_by(self.total_lookup_count as u64);
156 self.join_lookup_miss_count_metric
157 .inc_by(self.lookup_miss_count as u64);
158 self.join_insert_cache_miss_count_metrics
159 .inc_by(self.insert_cache_miss_count as u64);
160 self.total_lookup_count = 0;
161 self.lookup_miss_count = 0;
162 self.insert_cache_miss_count = 0;
163 }
164}
165
166struct InequalityKeyDesc {
168 idx: usize,
169 serializer: OrderedRowSerde,
170}
171
172impl InequalityKeyDesc {
173 pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
175 let indices = vec![self.idx];
176 let inequality_key = row.project(&indices);
177 inequality_key.memcmp_serialize(&self.serializer)
178 }
179}
180
181pub struct JoinHashMap<K: HashKey, S: StateStore> {
182 inner: JoinHashMapInner<K>,
184 join_key_data_types: Vec<DataType>,
186 null_matched: K::Bitmap,
188 pk_serializer: OrderedRowSerde,
190 state: TableInner<S>,
192 degree_state: Option<TableInner<S>>,
223 need_degree_table: bool,
226 pk_contained_in_jk: bool,
228 inequality_key_desc: Option<InequalityKeyDesc>,
230 metrics: JoinHashMapMetrics,
232}
233
234impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
235 pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
236 (&self.state.order_key_indices, &mut self.degree_state)
237 }
238
239 pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
246 &'a mut self,
247 key: &'a K,
248 ) -> StreamExecutorResult<(
249 impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
250 &'a [usize],
251 &'a mut Option<TableInner<S>>,
252 )> {
253 let degree_state = &mut self.degree_state;
254 let (order_key_indices, pk_indices, state_table) = (
255 &self.state.order_key_indices,
256 &self.state.pk_indices,
257 &mut self.state.table,
258 );
259 let degrees = if let Some(degree_state) = degree_state {
260 Some(fetch_degrees(key, &self.join_key_data_types, °ree_state.table).await?)
261 } else {
262 None
263 };
264 let stream = into_stream(
265 &self.join_key_data_types,
266 pk_indices,
267 &self.pk_serializer,
268 state_table,
269 key,
270 degrees,
271 );
272 Ok((stream, order_key_indices, &mut self.degree_state))
273 }
274}
275
276#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
277pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
278 join_key_data_types: &'a [DataType],
279 pk_indices: &'a [usize],
280 pk_serializer: &'a OrderedRowSerde,
281 state_table: &'a StateTable<S>,
282 key: &'a K,
283 degrees: Option<Vec<DegreeType>>,
284) {
285 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
286 let decoded_key = key.deserialize(join_key_data_types)?;
287 let table_iter = state_table
288 .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
289 .await?;
290
291 #[for_await]
292 for (i, entry) in table_iter.enumerate() {
293 let encoded_row = entry?;
294 let encoded_pk = encoded_row
295 .as_ref()
296 .project(pk_indices)
297 .memcmp_serialize(pk_serializer);
298 let join_row = JoinRow::new(
299 encoded_row.into_owned_row(),
300 degrees.as_ref().map_or(0, |d| d[i]),
301 );
302 yield (encoded_pk, join_row);
303 }
304}
305
306async fn fetch_degrees<K: HashKey, S: StateStore>(
330 key: &K,
331 join_key_data_types: &[DataType],
332 degree_state_table: &StateTable<S>,
333) -> StreamExecutorResult<Vec<DegreeType>> {
334 let key = key.deserialize(join_key_data_types)?;
335 let mut degrees = vec![];
336 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
337 let table_iter = degree_state_table
338 .iter_with_prefix(key, sub_range, PrefetchOptions::default())
339 .await
340 .unwrap();
341 #[for_await]
342 for entry in table_iter {
343 let degree_row = entry?;
344 let degree_i64 = degree_row
345 .datum_at(degree_row.len() - 1)
346 .expect("degree should not be NULL");
347 degrees.push(degree_i64.into_int64() as u64);
348 }
349 Ok(degrees)
350}
351
352pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
356 order_key_indices: &[usize],
357 degree_state: &mut TableInner<S>,
358 matched_row: &mut JoinRow<OwnedRow>,
359) {
360 let old_degree_row = matched_row
361 .row
362 .as_ref()
363 .project(order_key_indices)
364 .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
365 if INCREMENT {
366 matched_row.degree += 1;
367 } else {
368 matched_row.degree -= 1;
370 }
371 let new_degree_row = matched_row
372 .row
373 .as_ref()
374 .project(order_key_indices)
375 .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
376 degree_state.table.update(old_degree_row, new_degree_row);
377}
378
379pub struct TableInner<S: StateStore> {
380 pk_indices: Vec<usize>,
382 join_key_indices: Vec<usize>,
384 order_key_indices: Vec<usize>,
389 pub(crate) table: StateTable<S>,
390}
391
392impl<S: StateStore> TableInner<S> {
393 pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
394 let order_key_indices = table.pk_indices().to_vec();
395 Self {
396 pk_indices,
397 join_key_indices,
398 order_key_indices,
399 table,
400 }
401 }
402
403 fn error_context(&self, row: &impl Row) -> String {
404 let pk = row.project(&self.pk_indices);
405 let jk = row.project(&self.join_key_indices);
406 format!(
407 "join key: {}, pk: {}, row: {}, state_table_id: {}",
408 jk.display(),
409 pk.display(),
410 row.display(),
411 self.table.table_id()
412 )
413 }
414}
415
416impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
417 #[allow(clippy::too_many_arguments)]
419 pub fn new(
420 watermark_sequence: AtomicU64Ref,
421 join_key_data_types: Vec<DataType>,
422 state_join_key_indices: Vec<usize>,
423 state_all_data_types: Vec<DataType>,
424 state_table: StateTable<S>,
425 state_pk_indices: Vec<usize>,
426 degree_state: Option<TableInner<S>>,
427 null_matched: K::Bitmap,
428 pk_contained_in_jk: bool,
429 inequality_key_idx: Option<usize>,
430 metrics: Arc<StreamingMetrics>,
431 actor_id: ActorId,
432 fragment_id: FragmentId,
433 side: &'static str,
434 ) -> Self {
435 let alloc = StatsAlloc::new(Global).shared();
436 let pk_data_types = state_pk_indices
438 .iter()
439 .map(|i| state_all_data_types[*i].clone())
440 .collect();
441 let pk_serializer = OrderedRowSerde::new(
442 pk_data_types,
443 vec![OrderType::ascending(); state_pk_indices.len()],
444 );
445
446 let inequality_key_desc = inequality_key_idx.map(|idx| {
447 let serializer = OrderedRowSerde::new(
448 vec![state_all_data_types[idx].clone()],
449 vec![OrderType::ascending()],
450 );
451 InequalityKeyDesc { idx, serializer }
452 });
453
454 let join_table_id = state_table.table_id();
455 let state = TableInner {
456 pk_indices: state_pk_indices,
457 join_key_indices: state_join_key_indices,
458 order_key_indices: state_table.pk_indices().to_vec(),
459 table: state_table,
460 };
461
462 let need_degree_table = degree_state.is_some();
463
464 let metrics_info = MetricsInfo::new(
465 metrics.clone(),
466 join_table_id,
467 actor_id,
468 format!("hash join {}", side),
469 );
470
471 let cache = ManagedLruCache::unbounded_with_hasher_in(
472 watermark_sequence,
473 metrics_info,
474 PrecomputedBuildHasher,
475 alloc,
476 );
477
478 Self {
479 inner: cache,
480 join_key_data_types,
481 null_matched,
482 pk_serializer,
483 state,
484 degree_state,
485 need_degree_table,
486 pk_contained_in_jk,
487 inequality_key_desc,
488 metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
489 }
490 }
491
492 pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
493 self.state.table.init_epoch(epoch).await?;
494 if let Some(degree_state) = &mut self.degree_state {
495 degree_state.table.init_epoch(epoch).await?;
496 }
497 Ok(())
498 }
499}
500
501impl<K: HashKey, S: StateStore> JoinHashMapPostCommit<'_, K, S> {
502 pub async fn post_yield_barrier(
503 self,
504 vnode_bitmap: Option<Arc<Bitmap>>,
505 ) -> StreamExecutorResult<Option<bool>> {
506 let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
507 if let Some(degree_state) = self.degree_state {
508 let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
509 }
510 let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
511 if cache_may_stale.unwrap_or(false) {
512 self.inner.clear();
513 }
514 Ok(cache_may_stale)
515 }
516}
517impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
518 pub fn update_watermark(&mut self, watermark: ScalarImpl) {
519 self.state.table.update_watermark(watermark.clone());
521 if let Some(degree_state) = &mut self.degree_state {
522 degree_state.table.update_watermark(watermark);
523 }
524 }
525
526 pub fn take_state_opt(&mut self, key: &K) -> CacheResult {
535 self.metrics.total_lookup_count += 1;
536 if self.inner.contains(key) {
537 tracing::trace!("hit cache for join key: {:?}", key);
538 let mut state = self.inner.peek_mut(key).unwrap();
541 CacheResult::Hit(state.take())
542 } else {
543 tracing::trace!("miss cache for join key: {:?}", key);
544 CacheResult::Miss
545 }
546 }
547
548 pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType> {
557 self.metrics.total_lookup_count += 1;
558 let state = if self.inner.contains(key) {
559 let mut state = self.inner.peek_mut(key).unwrap();
562 state.take()
563 } else {
564 self.metrics.lookup_miss_count += 1;
565 self.fetch_cached_state(key).await?.into()
566 };
567 Ok(state)
568 }
569
570 async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState> {
573 let key = key.deserialize(&self.join_key_data_types)?;
574
575 let mut entry_state = JoinEntryState::default();
576
577 if self.need_degree_table {
578 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
579 &(Bound::Unbounded, Bound::Unbounded);
580 let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
581 &key,
582 sub_range,
583 PrefetchOptions::default(),
584 );
585 let degree_state = self.degree_state.as_ref().unwrap();
586 let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
587 &key,
588 sub_range,
589 PrefetchOptions::default(),
590 );
591
592 let (table_iter, degree_table_iter) =
593 try_join(table_iter_fut, degree_table_iter_fut).await?;
594
595 let mut pinned_table_iter = std::pin::pin!(table_iter);
596 let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);
597
598 let mut rows = vec![];
601 let mut degree_rows = vec![];
602 let mut inconsistency_happened = false;
603 loop {
604 let (row, degree_row) =
605 join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
606 let (row, degree_row) = match (row, degree_row) {
607 (None, None) => break,
608 (None, Some(_)) => {
609 inconsistency_happened = true;
610 consistency_panic!(
611 "mismatched row and degree table of join key: {:?}, degree table has more rows",
612 &key
613 );
614 break;
615 }
616 (Some(_), None) => {
617 inconsistency_happened = true;
618 consistency_panic!(
619 "mismatched row and degree table of join key: {:?}, input table has more rows",
620 &key
621 );
622 break;
623 }
624 (Some(r), Some(d)) => (r, d),
625 };
626
627 let row = row?;
628 let degree_row = degree_row?;
629 rows.push(row);
630 degree_rows.push(degree_row);
631 }
632
633 if inconsistency_happened {
634 assert_ne!(rows.len(), degree_rows.len());
636
637 let row_iter = stream::iter(rows.into_iter()).peekable();
638 let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
639 pin_mut!(row_iter);
640 pin_mut!(degree_row_iter);
641
642 loop {
643 match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
644 (None, _) | (_, None) => break,
645 (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
646 Ordering::Greater => {
647 degree_row_iter.next().await;
648 }
649 Ordering::Less => {
650 row_iter.next().await;
651 }
652 Ordering::Equal => {
653 let row = row_iter.next().await.unwrap();
654 let degree_row = degree_row_iter.next().await.unwrap();
655
656 let pk = row
657 .as_ref()
658 .project(&self.state.pk_indices)
659 .memcmp_serialize(&self.pk_serializer);
660 let degree_i64 = degree_row
661 .datum_at(degree_row.len() - 1)
662 .expect("degree should not be NULL");
663 let inequality_key = self
664 .inequality_key_desc
665 .as_ref()
666 .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
667 entry_state
668 .insert(
669 pk,
670 JoinRow::new(row.row(), degree_i64.into_int64() as u64)
671 .encode(),
672 inequality_key,
673 )
674 .with_context(|| self.state.error_context(row.row()))?;
675 }
676 },
677 }
678 }
679 } else {
680 assert_eq!(rows.len(), degree_rows.len());
685
686 #[for_await]
687 for (row, degree_row) in
688 stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
689 {
690 let pk1 = row.key();
691 let pk2 = degree_row.key();
692 debug_assert_eq!(
693 pk1, pk2,
694 "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
695 );
696 let pk = row
697 .as_ref()
698 .project(&self.state.pk_indices)
699 .memcmp_serialize(&self.pk_serializer);
700 let inequality_key = self
701 .inequality_key_desc
702 .as_ref()
703 .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
704 let degree_i64 = degree_row
705 .datum_at(degree_row.len() - 1)
706 .expect("degree should not be NULL");
707 entry_state
708 .insert(
709 pk,
710 JoinRow::new(row.row(), degree_i64.into_int64() as u64).encode(),
711 inequality_key,
712 )
713 .with_context(|| self.state.error_context(row.row()))?;
714 }
715 }
716 } else {
717 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
718 &(Bound::Unbounded, Bound::Unbounded);
719 let table_iter = self
720 .state
721 .table
722 .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
723 .await?;
724
725 #[for_await]
726 for entry in table_iter {
727 let row = entry?;
728 let pk = row
729 .as_ref()
730 .project(&self.state.pk_indices)
731 .memcmp_serialize(&self.pk_serializer);
732 let inequality_key = self
733 .inequality_key_desc
734 .as_ref()
735 .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
736 entry_state
737 .insert(pk, JoinRow::new(row.row(), 0).encode(), inequality_key)
738 .with_context(|| self.state.error_context(row.row()))?;
739 }
740 };
741
742 Ok(entry_state)
743 }
744
745 pub async fn flush(
746 &mut self,
747 epoch: EpochPair,
748 ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S>> {
749 self.metrics.flush();
750 let state_post_commit = self.state.table.commit(epoch).await?;
751 let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
752 Some(degree_state.table.commit(epoch).await?)
753 } else {
754 None
755 };
756 Ok(JoinHashMapPostCommit {
757 state: state_post_commit,
758 degree_state: degree_state_post_commit,
759 inner: &mut self.inner,
760 })
761 }
762
763 pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
764 self.state.table.try_flush().await?;
765 if let Some(degree_state) = &mut self.degree_state {
766 degree_state.table.try_flush().await?;
767 }
768 Ok(())
769 }
770
771 pub fn insert_handle_degree(
772 &mut self,
773 key: &K,
774 value: JoinRow<impl Row>,
775 ) -> StreamExecutorResult<()> {
776 if self.need_degree_table {
777 self.insert(key, value)
778 } else {
779 self.insert_row(key, value.row)
780 }
781 }
782
783 pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
785 let pk = self.serialize_pk_from_row(&value.row);
786
787 let inequality_key = self
788 .inequality_key_desc
789 .as_ref()
790 .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
791
792 if self.inner.contains(key) {
795 let mut entry = self.inner.get_mut(key).unwrap();
797 entry
798 .insert(pk, value.encode(), inequality_key)
799 .with_context(|| self.state.error_context(&value.row))?;
800 } else if self.pk_contained_in_jk {
801 self.metrics.insert_cache_miss_count += 1;
803 let mut entry = JoinEntryState::default();
804 entry
805 .insert(pk, value.encode(), inequality_key)
806 .with_context(|| self.state.error_context(&value.row))?;
807 self.update_state(key, entry.into());
808 }
809
810 if let Some(degree_state) = self.degree_state.as_mut() {
812 let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
813 self.state.table.insert(row);
814 degree_state.table.insert(degree);
815 } else {
816 self.state.table.insert(value.row);
817 }
818 Ok(())
819 }
820
821 pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
824 let join_row = JoinRow::new(&value, 0);
825 self.insert(key, join_row)?;
826 Ok(())
827 }
828
829 pub fn delete_handle_degree(
830 &mut self,
831 key: &K,
832 value: JoinRow<impl Row>,
833 ) -> StreamExecutorResult<()> {
834 if self.need_degree_table {
835 self.delete(key, value)
836 } else {
837 self.delete_row(key, value.row)
838 }
839 }
840
841 pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
843 if let Some(mut entry) = self.inner.get_mut(key) {
844 let pk = (&value.row)
845 .project(&self.state.pk_indices)
846 .memcmp_serialize(&self.pk_serializer);
847 let inequality_key = self
848 .inequality_key_desc
849 .as_ref()
850 .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
851 entry
852 .remove(pk, inequality_key.as_ref())
853 .with_context(|| self.state.error_context(&value.row))?;
854 }
855
856 let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
858 self.state.table.delete(row);
859 let degree_state = self.degree_state.as_mut().unwrap();
860 degree_state.table.delete(degree);
861 Ok(())
862 }
863
864 pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
867 if let Some(mut entry) = self.inner.get_mut(key) {
868 let pk = (&value)
869 .project(&self.state.pk_indices)
870 .memcmp_serialize(&self.pk_serializer);
871
872 let inequality_key = self
873 .inequality_key_desc
874 .as_ref()
875 .map(|desc| desc.serialize_inequal_key_from_row(&value));
876 entry
877 .remove(pk, inequality_key.as_ref())
878 .with_context(|| self.state.error_context(&value))?;
879 }
880
881 self.state.table.delete(value);
883 Ok(())
884 }
885
886 pub fn update_state(&mut self, key: &K, state: HashValueType) {
888 self.inner.put(key.clone(), HashValueWrapper(Some(state)));
889 }
890
891 pub fn evict(&mut self) {
893 self.inner.evict();
894 }
895
896 pub fn entry_count(&self) -> usize {
898 self.inner.len()
899 }
900
901 pub fn null_matched(&self) -> &K::Bitmap {
902 &self.null_matched
903 }
904
905 pub fn table_id(&self) -> u32 {
906 self.state.table.table_id()
907 }
908
909 pub fn join_key_data_types(&self) -> &[DataType] {
910 &self.join_key_data_types
911 }
912
913 pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
917 let desc = self.inequality_key_desc.as_ref().unwrap();
918 row.datum_at(desc.idx).is_none()
919 }
920
921 pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
925 self.inequality_key_desc
926 .as_ref()
927 .unwrap()
928 .serialize_inequal_key_from_row(&row)
929 }
930
931 pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
932 row.project(&self.state.pk_indices)
933 .memcmp_serialize(&self.pk_serializer)
934 }
935}
936
937#[must_use]
938pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore> {
939 state: StateTablePostCommit<'a, S>,
940 degree_state: Option<StateTablePostCommit<'a, S>>,
941 inner: &'a mut JoinHashMapInner<K>,
942}
943
944use risingwave_common_estimate_size::KvSize;
945use thiserror::Error;
946
947use super::*;
948use crate::executor::prelude::{Stream, try_stream};
949
950#[derive(Default)]
956pub struct JoinEntryState {
957 cached: JoinRowSet<PkType, StateValueType>,
959 inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
961 kv_heap_size: KvSize,
962}
963
964impl EstimateSize for JoinEntryState {
965 fn estimated_heap_size(&self) -> usize {
966 self.kv_heap_size.size()
969 }
970}
971
972#[derive(Error, Debug)]
973pub enum JoinEntryError {
974 #[error("double inserting a join state entry")]
975 Occupied,
976 #[error("removing a join state entry but it is not in the cache")]
977 Remove,
978 #[error("retrieving a pk from the inequality index but it is not in the cache")]
979 InequalIndex,
980}
981
982impl JoinEntryState {
983 pub fn insert(
985 &mut self,
986 key: PkType,
987 value: StateValueType,
988 inequality_key: Option<InequalKeyType>,
989 ) -> Result<&mut StateValueType, JoinEntryError> {
990 let mut removed = false;
991 if !enable_strict_consistency() {
992 if let Some(old_value) = self.cached.remove(&key) {
994 if let Some(inequality_key) = inequality_key.as_ref() {
995 self.remove_pk_from_inequality_index(&key, inequality_key);
996 }
997 self.kv_heap_size.sub(&key, &old_value);
998 removed = true;
999 }
1000 }
1001
1002 self.kv_heap_size.add(&key, &value);
1003
1004 if let Some(inequality_key) = inequality_key {
1005 self.insert_pk_to_inequality_index(key.clone(), inequality_key);
1006 }
1007 let ret = self.cached.try_insert(key.clone(), value);
1008
1009 if !enable_strict_consistency() {
1010 assert!(ret.is_ok(), "we have removed existing entry, if any");
1011 if removed {
1012 consistency_error!(?key, "double inserting a join state entry");
1014 }
1015 }
1016
1017 ret.map_err(|_| JoinEntryError::Occupied)
1018 }
1019
1020 pub fn remove(
1022 &mut self,
1023 pk: PkType,
1024 inequality_key: Option<&InequalKeyType>,
1025 ) -> Result<(), JoinEntryError> {
1026 if let Some(value) = self.cached.remove(&pk) {
1027 self.kv_heap_size.sub(&pk, &value);
1028 if let Some(inequality_key) = inequality_key {
1029 self.remove_pk_from_inequality_index(&pk, inequality_key);
1030 }
1031 Ok(())
1032 } else if enable_strict_consistency() {
1033 Err(JoinEntryError::Remove)
1034 } else {
1035 consistency_error!(?pk, "removing a join state entry but it's not in the cache");
1036 Ok(())
1037 }
1038 }
1039
1040 fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
1041 if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
1042 if pk_set.remove(pk).is_none() {
1043 if enable_strict_consistency() {
1044 panic!("removing a pk that it not in the inequality index");
1045 } else {
1046 consistency_error!(?pk, "removing a pk that it not in the inequality index");
1047 };
1048 } else {
1049 self.kv_heap_size.sub(pk, &());
1050 }
1051 if pk_set.is_empty() {
1052 self.inequality_index.remove(inequality_key);
1053 }
1054 }
1055 }
1056
1057 fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
1058 if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
1059 let pk_size = pk.estimated_size();
1060 if pk_set.try_insert(pk, ()).is_err() {
1061 if enable_strict_consistency() {
1062 panic!("inserting a pk that it already in the inequality index");
1063 } else {
1064 consistency_error!("inserting a pk that it already in the inequality index");
1065 };
1066 } else {
1067 self.kv_heap_size.add_size(pk_size);
1068 }
1069 } else {
1070 let mut pk_set = JoinRowSet::default();
1071 pk_set.try_insert(pk, ()).unwrap();
1072 self.inequality_index
1073 .try_insert(inequality_key, pk_set)
1074 .unwrap();
1075 }
1076 }
1077
1078 pub fn get(
1079 &self,
1080 pk: &PkType,
1081 data_types: &[DataType],
1082 ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
1083 self.cached
1084 .get(pk)
1085 .map(|encoded| encoded.decode(data_types))
1086 }
1087
1088 pub fn values_mut<'a>(
1094 &'a mut self,
1095 data_types: &'a [DataType],
1096 ) -> impl Iterator<
1097 Item = (
1098 &'a mut StateValueType,
1099 StreamExecutorResult<JoinRow<OwnedRow>>,
1100 ),
1101 > + 'a {
1102 self.cached.values_mut().map(|encoded| {
1103 let decoded = encoded.decode(data_types);
1104 (encoded, decoded)
1105 })
1106 }
1107
1108 pub fn len(&self) -> usize {
1109 self.cached.len()
1110 }
1111
1112 pub fn range_by_inequality<'a, R>(
1114 &'a self,
1115 range: R,
1116 data_types: &'a [DataType],
1117 ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<OwnedRow>>> + 'a
1118 where
1119 R: RangeBounds<InequalKeyType> + 'a,
1120 {
1121 self.inequality_index.range(range).flat_map(|(_, pk_set)| {
1122 pk_set
1123 .keys()
1124 .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
1125 })
1126 }
1127
1128 pub fn upper_bound_by_inequality<'a>(
1130 &'a self,
1131 bound: Bound<&InequalKeyType>,
1132 data_types: &'a [DataType],
1133 ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
1134 if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
1135 if let Some(pk) = pk_set.first_key_sorted() {
1136 self.get_by_indexed_pk(pk, data_types)
1137 } else {
1138 panic!("pk set for a index record must has at least one element");
1139 }
1140 } else {
1141 None
1142 }
1143 }
1144
1145 pub fn get_by_indexed_pk(
1146 &self,
1147 pk: &PkType,
1148 data_types: &[DataType],
1149 ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>>
1150where {
1151 if let Some(value) = self.cached.get(pk) {
1152 Some(value.decode(data_types))
1153 } else if enable_strict_consistency() {
1154 Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
1155 } else {
1156 consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
1157 None
1158 }
1159 }
1160
1161 pub fn lower_bound_by_inequality<'a>(
1163 &'a self,
1164 bound: Bound<&InequalKeyType>,
1165 data_types: &'a [DataType],
1166 ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
1167 if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
1168 if let Some(pk) = pk_set.first_key_sorted() {
1169 self.get_by_indexed_pk(pk, data_types)
1170 } else {
1171 panic!("pk set for a index record must has at least one element");
1172 }
1173 } else {
1174 None
1175 }
1176 }
1177
1178 pub fn get_first_by_inequality<'a>(
1179 &'a self,
1180 inequality_key: &InequalKeyType,
1181 data_types: &'a [DataType],
1182 ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
1183 if let Some(pk_set) = self.inequality_index.get(inequality_key) {
1184 if let Some(pk) = pk_set.first_key_sorted() {
1185 self.get_by_indexed_pk(pk, data_types)
1186 } else {
1187 panic!("pk set for a index record must has at least one element");
1188 }
1189 } else {
1190 None
1191 }
1192 }
1193
1194 pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
1195 &self.inequality_index
1196 }
1197}
1198
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;

    /// Inserts every row of `data_chunk` into `managed_state` with degree 0.
    ///
    /// Each row is keyed by the memcomparable encoding of its `pk_indices` columns. When
    /// `inequality_key_idx` is given, the row is also indexed by that column.
    fn insert_chunk(
        managed_state: &mut JoinEntryState,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            // Serialize exactly the pk columns, in `pk_indices` order, matching the serializer.
            let pk = (&row)
                .project(pk_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, join_row.encode(), inequality_key)
                .unwrap();
        }
    }

    /// Asserts that iterating `managed_state` yields rows `(col1[i], col2[i])` with degree 0.
    fn check(
        managed_state: &mut JoinEntryState,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row[0], Some(ScalarImpl::Int64(*d1)));
            assert_eq!(matched_row.row[1], Some(ScalarImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}