use std::alloc::Global;
use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use futures::future::{join, try_join};
use futures::{StreamExt, pin_mut, stream};
use futures_async_stream::for_await;
use join_row_set::JoinRowSet;
use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt, once};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;
use thiserror_ext::AsReport;

use super::row::{CachedJoinRow, DegreeType};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::{JoinEncoding, StreamExecutorError};
use crate::task::{ActorId, AtomicU64Ref, FragmentId};

/// The memcmp-encoded primary key of a row.
type PkType = Vec<u8>;
/// The memcmp-encoded inequality key of a row.
type InequalKeyType = Vec<u8>;

pub type HashValueType<E> = Box<JoinEntryState<E>>;

impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
    fn estimated_heap_size(&self) -> usize {
        self.as_ref().estimated_heap_size()
    }
}

/// Wrapper around the cached join entry, so that the entry can be taken out of
/// the LRU cache temporarily and put back later.
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);

/// The result of probing the cache for a join key.
pub(crate) enum CacheResult<E: JoinEncoding> {
    /// The key can never match, so there is nothing to fetch.
    NeverMatch,
    /// Not in the cache; the entry must be fetched from storage.
    Miss,
    /// In the cache; the entry is taken out and returned.
    Hit(HashValueType<E>),
}

impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        self.0.estimated_heap_size()
    }
}

impl<E: JoinEncoding> HashValueWrapper<E> {
    const MESSAGE: &'static str = "the state should always be `Some`";

    /// Take the value out of the wrapper. Panics if it has already been taken.
    pub fn take(&mut self) -> HashValueType<E> {
        self.0.take().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
    type Target = HashValueType<E>;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut().expect(Self::MESSAGE)
    }
}
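
// Design note: the `Option` indirection lets `take_state` move the boxed entry
// out of the LRU slot while keeping the slot alive, and `update_state` puts it
// back via `HashValueWrapper(Some(..))`. Outside of that take/put window the
// value is always `Some`, hence the `expect(Self::MESSAGE)` in the accessors.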

type JoinHashMapInner<K, E> =
    ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;

pub struct JoinHashMapMetrics {
    /// Lookup misses accumulated since the last flush.
    lookup_miss_count: usize,
    /// Total lookups accumulated since the last flush.
    total_lookup_count: usize,
    /// Inserts that missed the cache, accumulated since the last flush.
    insert_cache_miss_count: usize,

    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}

impl JoinHashMapMetrics {
    pub fn new(
        metrics: &StreamingMetrics,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
        join_table_id: u32,
    ) -> Self {
        let actor_id = actor_id.to_string();
        let fragment_id = fragment_id.to_string();
        let join_table_id = join_table_id.to_string();
        let join_lookup_total_count_metric = metrics
            .join_lookup_total_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_lookup_miss_count_metric = metrics
            .join_lookup_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_insert_cache_miss_count_metrics = metrics
            .join_insert_cache_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);

        Self {
            lookup_miss_count: 0,
            total_lookup_count: 0,
            insert_cache_miss_count: 0,
            join_lookup_total_count_metric,
            join_lookup_miss_count_metric,
            join_insert_cache_miss_count_metrics,
        }
    }

    pub fn flush(&mut self) {
        self.join_lookup_total_count_metric
            .inc_by(self.total_lookup_count as u64);
        self.join_lookup_miss_count_metric
            .inc_by(self.lookup_miss_count as u64);
        self.join_insert_cache_miss_count_metrics
            .inc_by(self.insert_cache_miss_count as u64);
        self.total_lookup_count = 0;
        self.lookup_miss_count = 0;
        self.insert_cache_miss_count = 0;
    }
}
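
// Note: the `usize` counters accumulate locally and are folded into the
// guarded Prometheus counters only in `flush`, which `JoinHashMap::flush`
// calls once per state commit; this keeps per-row lookups from touching the
// shared metrics.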

/// Describes the inequality-key column and the serializer used to encode it.
struct InequalityKeyDesc {
    idx: usize,
    serializer: OrderedRowSerde,
}

impl InequalityKeyDesc {
    /// Serialize the memcmp-encoded inequality key from a row.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        row.project(&[self.idx]).memcmp_serialize(&self.serializer)
    }
}
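
// Example (a sketch, not from the original source): with `idx = 1` and a
// serializer over a single ascending `Int64` column, serializing the row
// `(10, 42)` yields the memcmp encoding of `42`. Memcmp-encoded keys compare
// bytewise in the same order as the underlying values, which is what lets
// them serve as ordered keys in the inequality index below.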

pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// In-memory LRU cache of join entries.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns.
    join_key_data_types: Vec<DataType>,
    /// Null-safe bitmap for the join key.
    null_matched: K::Bitmap,
    /// Serializer for the primary key.
    pk_serializer: OrderedRowSerde,
    /// The state table holding the rows.
    state: TableInner<S>,
    /// The state table holding the degrees, if degrees are maintained.
    degree_state: Option<TableInner<S>>,
    /// Whether degrees need to be maintained (i.e. `degree_state` is `Some`).
    need_degree_table: bool,
    /// Whether the primary key columns are all contained in the join key.
    pk_contained_in_jk: bool,
    /// Metadata for the inequality key, if one is used.
    inequality_key_desc: Option<InequalityKeyDesc>,
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// Fetch the matched rows for the given join key from the state table,
    /// together with their degrees (if a degree table is maintained).
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}

#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        let join_row = JoinRow::new(
            encoded_row.into_owned_row(),
            degrees.as_ref().map_or(0, |d| d[i]),
        );
        yield (encoded_pk, join_row);
    }
}
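
// Usage sketch (illustrative; `lookup_side` is a hypothetical binding): the
// stream is driven by the join executor after calling
// `fetch_matched_rows_and_get_degree_table_ref`, e.g.
//
//     let (stream, order_key_indices, degree_state) = lookup_side
//         .fetch_matched_rows_and_get_degree_table_ref(&key)
//         .await?;
//     #[for_await]
//     for entry in stream {
//         let (encoded_pk, matched_row) = entry?;
//         // ... match against the input row and update degrees ...
//     }
//
// Rows arrive in state table key order, with `degrees[i]` zipped positionally
// onto the `i`-th row by `into_stream` above.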

/// Fetch the degrees of all rows under the given join key from the degree
/// table. The degrees are returned in the table's iteration order, which
/// matches the order of the corresponding rows in the row state table.
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
        .await?;
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        let degree_i64 = degree_row
            .datum_at(degree_row.len() - 1)
            .expect("degree should not be NULL");
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}

/// Increment (`INCREMENT = true`) or decrement the degree of a matched row,
/// recording the change as an update to the degree table.
pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
    order_key_indices: &[usize],
    degree_state: &mut TableInner<S>,
    matched_row: &mut JoinRow<OwnedRow>,
) {
    let old_degree_row = matched_row
        .row
        .as_ref()
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    if INCREMENT {
        matched_row.degree += 1;
    } else {
        // DECREMENT
        matched_row.degree -= 1;
    }
    let new_degree_row = matched_row
        .row
        .as_ref()
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    degree_state.table.update(old_degree_row, new_degree_row);
}
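
// Shape sketch (illustrative): with `order_key_indices = [0]` and a matched
// row `(1, 'a')` currently at degree 2, an increment issues
// `update((1, 2), (1, 3))` against the degree table; that is, the degree is
// stored as a trailing `Int64` column appended after the order key columns.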

pub struct TableInner<S: StateStore> {
    /// Indices of the primary key columns in the table row.
    pk_indices: Vec<usize>,
    /// Indices of the join key columns in the table row.
    join_key_indices: Vec<usize>,
    /// The order key of the state table (i.e. its storage primary key).
    order_key_indices: Vec<usize>,
    pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
        let order_key_indices = table.pk_indices().to_vec();
        Self {
            pk_indices,
            join_key_indices,
            order_key_indices,
            table,
        }
    }

    fn error_context(&self, row: &impl Row) -> String {
        let pk = row.project(&self.pk_indices);
        let jk = row.project(&self.join_key_indices);
        format!(
            "join key: {}, pk: {}, row: {}, state_table_id: {}",
            jk.display(),
            pk.display(),
            row.display(),
            self.table.table_id()
        )
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a `JoinHashMap` backed by the given row state table and an
    /// optional degree state table.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        let alloc = StatsAlloc::new(Global).shared();

        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        let cache = ManagedLruCache::unbounded_with_hasher_in(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
            alloc,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        if cache_may_stale.unwrap_or(false) {
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }

    /// Take the entry out of the cache if it is present, without falling back
    /// to storage on a miss. The caller is expected to put the entry back via
    /// `update_state` once done with it.
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            let mut state = self.inner.peek_mut(key).unwrap();
            CacheResult::Hit(state.take())
        } else {
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }

    /// Take the entry for the key, fetching it from the state table on a
    /// cache miss. The caller is expected to put the entry back via
    /// `update_state`.
    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
        self.metrics.total_lookup_count += 1;
        let state = if self.inner.contains(key) {
            let mut state = self.inner.peek_mut(key).unwrap();
            state.take()
        } else {
            self.metrics.lookup_miss_count += 1;
            self.fetch_cached_state(key).await?.into()
        };
        Ok(state)
    }

    /// Fetch the entire join entry for the key from the row state table (and
    /// the degree table, if one is maintained), building an in-memory
    /// `JoinEntryState`.
    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
        let key = key.deserialize(&self.join_key_data_types)?;

        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();

        if self.need_degree_table {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );
            let degree_state = self.degree_state.as_ref().unwrap();
            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );

            let (table_iter, degree_table_iter) =
                try_join(table_iter_fut, degree_table_iter_fut).await?;

            let mut pinned_table_iter = std::pin::pin!(table_iter);
            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);

            // Zip the two iterators in lockstep, detecting any length mismatch
            // between the row table and the degree table along the way.
            let mut rows = vec![];
            let mut degree_rows = vec![];
            let mut inconsistency_happened = false;
            loop {
                let (row, degree_row) =
                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
                let (row, degree_row) = match (row, degree_row) {
                    (None, None) => break,
                    (None, Some(_)) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(_), None) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, input table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(r), Some(d)) => (r, d),
                };

                let row = row?;
                let degree_row = degree_row?;
                rows.push(row);
                degree_rows.push(degree_row);
            }

            if inconsistency_happened {
                // Pair rows and degree rows by primary key instead of by
                // position, skipping entries that exist on only one side.
                // The vectors were filled in lockstep, so their lengths are
                // equal even though an inconsistency was detected partway
                // through one of the iterators.
                assert_eq!(rows.len(), degree_rows.len());

                let row_iter = stream::iter(rows.into_iter()).peekable();
                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
                pin_mut!(row_iter);
                pin_mut!(degree_row_iter);

                loop {
                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
                        (None, _) | (_, None) => break,
                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
                            Ordering::Greater => {
                                degree_row_iter.next().await;
                            }
                            Ordering::Less => {
                                row_iter.next().await;
                            }
                            Ordering::Equal => {
                                let row = row_iter.next().await.unwrap();
                                let degree_row = degree_row_iter.next().await.unwrap();

                                let pk = row
                                    .as_ref()
                                    .project(&self.state.pk_indices)
                                    .memcmp_serialize(&self.pk_serializer);
                                let degree_i64 = degree_row
                                    .datum_at(degree_row.len() - 1)
                                    .expect("degree should not be NULL");
                                let inequality_key = self
                                    .inequality_key_desc
                                    .as_ref()
                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                                entry_state
                                    .insert(
                                        pk,
                                        E::encode(&JoinRow::new(
                                            row.row(),
                                            degree_i64.into_int64() as u64,
                                        )),
                                        inequality_key,
                                    )
                                    .with_context(|| self.state.error_context(row.row()))?;
                            }
                        },
                    }
                }
            } else {
                // The two tables are consistent, so rows and degree rows pair
                // up positionally.
                assert_eq!(rows.len(), degree_rows.len());

                #[for_await]
                for (row, degree_row) in
                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
                {
                    let pk1 = row.key();
                    let pk2 = degree_row.key();
                    debug_assert_eq!(
                        pk1, pk2,
                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
                    );
                    let pk = row
                        .as_ref()
                        .project(&self.state.pk_indices)
                        .memcmp_serialize(&self.pk_serializer);
                    let inequality_key = self
                        .inequality_key_desc
                        .as_ref()
                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                    let degree_i64 = degree_row
                        .datum_at(degree_row.len() - 1)
                        .expect("degree should not be NULL");
                    entry_state
                        .insert(
                            pk,
                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
                            inequality_key,
                        )
                        .with_context(|| self.state.error_context(row.row()))?;
                }
            }
        } else {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter = self
                .state
                .table
                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
                .await?;

            #[for_await]
            for entry in table_iter {
                let row = entry?;
                let pk = row
                    .as_ref()
                    .project(&self.state.pk_indices)
                    .memcmp_serialize(&self.pk_serializer);
                let inequality_key = self
                    .inequality_key_desc
                    .as_ref()
                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                entry_state
                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
                    .with_context(|| self.state.error_context(row.row()))?;
            }
        };

        Ok(entry_state)
    }

    /// Flush the metrics and commit both state tables for the given epoch.
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
        self.metrics.flush();
        let state_post_commit = self.state.table.commit(epoch).await?;
        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
            Some(degree_state.table.commit(epoch).await?)
        } else {
            None
        };
        Ok(JoinHashMapPostCommit {
            state: state_post_commit,
            degree_state: degree_state_post_commit,
            inner: &mut self.inner,
        })
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state.table.try_flush().await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.try_flush().await?;
        }
        Ok(())
    }

    /// Insert a row, maintaining the degree table if one exists.
    pub fn insert_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.insert(key, value)
        } else {
            self.insert_row(key, value.row)
        }
    }

    /// Insert a [`JoinRow`] into the cached entry (if cached) and the state
    /// table, maintaining the degree table if one exists.
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        if self.inner.contains(key) {
            // Update the cached entry in place.
            let mut entry = self.inner.get_mut(key).unwrap();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            // When the primary key is contained in the join key, each join key
            // maps to at most one row, so a fresh cache entry can be built
            // from this insert alone without consulting storage.
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }

    /// Insert a row with degree 0; used when no degree table is maintained.
    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        let join_row = JoinRow::new(&value, 0);
        self.insert(key, join_row)?;
        Ok(())
    }

    /// Delete a row, maintaining the degree table if one exists.
    pub fn delete_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.delete(key, value)
        } else {
            self.delete_row(key, value.row)
        }
    }

    /// Delete a [`JoinRow`] and its degree row. Requires the degree table.
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value.row)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);
            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value.row))?;
        }

        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
        self.state.table.delete(row);
        let degree_state = self.degree_state.as_mut().unwrap();
        degree_state.table.delete(degree);
        Ok(())
    }

    /// Delete a row; used when no degree table is maintained.
    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);

            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value))?;
        }

        self.state.table.delete(value);
        Ok(())
    }

    /// Put an entry (back) into the cache, e.g. after `take_state`.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict the cache according to the watermark sequence.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Number of cached join keys.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    pub fn table_id(&self) -> u32 {
        self.state.table.table_id()
    }

    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Returns whether the inequality key column of the row is NULL.
    /// Panics if no inequality key is configured.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self.inequality_key_desc.as_ref().unwrap();
        row.datum_at(desc.idx).is_none()
    }

    /// Serialize the memcmp-encoded inequality key from a row.
    /// Panics if no inequality key is configured.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .unwrap()
            .serialize_inequal_key_from_row(&row)
    }

    /// Serialize the memcmp-encoded primary key from a row.
    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
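
    // Example (a sketch, not from the original source): with
    // `state.pk_indices = [0]` and an ascending `Int64` pk serializer,
    // `serialize_pk_from_row` on the row `(42, "x")` returns the memcmp
    // encoding of `42`; these are the same bytes used as `PkType` keys inside
    // `JoinEntryState`.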
}

#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    inner: &'a mut JoinHashMapInner<K, E>,
}

use risingwave_common_estimate_size::KvSize;
use thiserror::Error;

use super::*;
use crate::executor::prelude::{Stream, try_stream};

/// The in-memory cached entry for a single join key: all rows under that join
/// key, plus an optional index on the inequality key.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The full set of rows under the join key, keyed by memcmp-encoded pk.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Index from memcmp-encoded inequality key to the set of pks that carry it.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    kv_heap_size: KvSize,
}

impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        self.kv_heap_size.size()
    }
}

#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}

impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert a row into the entry. Returns `JoinEntryError::Occupied` on a
    /// double insert under strict consistency.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // Strict consistency is off: tolerate a double insert by replacing
            // the existing entry, but still report the anomaly below.
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }

    /// Remove a row from the entry. Returns `JoinEntryError::Remove` when the
    /// pk is absent under strict consistency.
    pub fn remove(
        &mut self,
        pk: PkType,
        inequality_key: Option<&InequalKeyType>,
    ) -> Result<(), JoinEntryError> {
        if let Some(value) = self.cached.remove(&pk) {
            self.kv_heap_size.sub(&pk, &value);
            if let Some(inequality_key) = inequality_key {
                self.remove_pk_from_inequality_index(&pk, inequality_key);
            }
            Ok(())
        } else if enable_strict_consistency() {
            Err(JoinEntryError::Remove)
        } else {
            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
            Ok(())
        }
    }

    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
            if pk_set.remove(pk).is_none() {
                if enable_strict_consistency() {
                    panic!("removing a pk that is not in the inequality index");
                } else {
                    consistency_error!(?pk, "removing a pk that is not in the inequality index");
                };
            } else {
                self.kv_heap_size.sub(pk, &());
            }
            if pk_set.is_empty() {
                self.inequality_index.remove(inequality_key);
            }
        }
    }

    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
            let pk_size = pk.estimated_size();
            if pk_set.try_insert(pk, ()).is_err() {
                if enable_strict_consistency() {
                    panic!("inserting a pk that is already in the inequality index");
                } else {
                    consistency_error!("inserting a pk that is already in the inequality index");
                };
            } else {
                self.kv_heap_size.add_size(pk_size);
            }
        } else {
            let mut pk_set = JoinRowSet::default();
            pk_set.try_insert(pk, ()).unwrap();
            self.inequality_index
                .try_insert(inequality_key, pk_set)
                .unwrap();
        }
    }

    pub fn get(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        self.cached
            .get(pk)
            .map(|encoded| encoded.decode(data_types))
    }

    /// Iterate over the rows in the entry, yielding a mutable reference to the
    /// encoded row (e.g. for degree updates) together with its decoded form.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<OwnedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }

    pub fn len(&self) -> usize {
        self.cached.len()
    }

    /// Iterate over all rows whose inequality key falls within `range`.
    pub fn range_by_inequality<'a, R>(
        &'a self,
        range: R,
        data_types: &'a [DataType],
    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<OwnedRow>>> + 'a
    where
        R: RangeBounds<InequalKeyType> + 'a,
    {
        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
            pk_set
                .keys()
                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
        })
    }

    /// Get the row with the largest inequality key within `bound`, breaking
    /// ties by the smallest pk.
    pub fn upper_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    /// Look up a pk that came from the inequality index. A missing row means
    /// the index and the cache are out of sync.
    pub fn get_by_indexed_pk(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some(value) = self.cached.get(pk) {
            Some(value.decode(data_types))
        } else if enable_strict_consistency() {
            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
        } else {
            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
            None
        }
    }

    /// Get the row with the smallest inequality key within `bound`, breaking
    /// ties by the smallest pk.
    pub fn lower_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    /// Get the row with the smallest pk among those sharing the given
    /// inequality key.
    pub fn get_first_by_inequality<'a>(
        &'a self,
        inequality_key: &InequalKeyType,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            let pk = OwnedRow::new(pk).memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row[0], Some(ScalarImpl::Int64(*d1)));
            assert_eq!(matched_row.row[1], Some(ScalarImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // Note: `check` follows `JoinRowSet`'s iteration order, which appears
        // to be insertion order while the set is small (above) and pk order
        // once it grows (below).
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }
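
    // A minimal additional test (a sketch, assuming only APIs used elsewhere
    // in this file): exercises `remove`, `len`, and `get` on `JoinEntryState`.
    // The pk bytes are built the same way `insert_chunk` builds them, from a
    // single ascending `Int64` pk column.
    #[tokio::test]
    async fn test_managed_join_state_remove() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let pk_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let data_chunk = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );
        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk,
        );
        assert_eq!(managed_state.len(), 3);

        // Remove the row with pk = 2; no inequality index is maintained here.
        let pk = OwnedRow::new(vec![Some(ScalarImpl::Int64(2))]).memcmp_serialize(&pk_serializer);
        managed_state.remove(pk.clone(), None).unwrap();
        assert_eq!(managed_state.len(), 2);
        assert!(managed_state.get(&pk, &col_types).is_none());

        // Removing the same pk again either errors (strict consistency) or
        // logs a consistency error and returns `Ok`.
        let _ = managed_state.remove(pk, None);
    }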

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // Among inequality keys <= 5, key 5 wins, and within its pk set the
        // smallest pk (1) is returned.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        // With the bound excluded, the entry for key 4 (pk 3) is returned.
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // No inequality key is strictly greater than 8.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
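
    // A minimal additional test (a sketch, assuming only APIs used elsewhere
    // in this file): exercises `get_first_by_inequality`. With the rows below,
    // the inequality index maps 4 -> {3, 4}, 5 -> {1, 2} and 8 -> {5}, and
    // `first_key_sorted` picks the smallest pk in a set.
    #[tokio::test]
    async fn test_get_first_by_inequality() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let data_chunk = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            Some(1),
            &data_chunk,
        );

        // The smallest pk among rows with inequality key 4 is 3.
        let key = OwnedRow::new(vec![Some(ScalarImpl::Int64(4))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .get_first_by_inequality(&key, &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        // No row carries inequality key 6.
        let key = OwnedRow::new(vec![Some(ScalarImpl::Int64(6))])
            .memcmp_serialize(&inequality_key_serializer);
        assert!(
            managed_state
                .get_first_by_inequality(&key, &col_types)
                .is_none()
        );
    }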
}