use std::alloc::Global;
use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use futures::future::{join, try_join};
use futures::{StreamExt, pin_mut, stream};
use futures_async_stream::for_await;
use join_row_set::JoinRowSet;
use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt, once};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::{EstimateSize, KvSize};
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::table::KeyedRow;
use thiserror::Error;
use thiserror_ext::AsReport;

use super::*;
use super::row::{CachedJoinRow, DegreeType};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::prelude::{Stream, try_stream};
use crate::executor::{JoinEncoding, StreamExecutorError};
use crate::task::{ActorId, AtomicU64Ref, FragmentId};
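
/// Memcmp-encoded primary key.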
type PkType = Vec<u8>;
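/// Memcmp-encoded inequality key.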
type InequalKeyType = Vec<u8>;
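
/// The per-join-key cache value: boxed so that the LRU map stores a thin pointer.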
pub type HashValueType<E> = Box<JoinEntryState<E>>;

impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
    fn estimated_heap_size(&self) -> usize {
        self.as_ref().estimated_heap_size()
    }
}
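
/// Wrapper around the cached entry. The `Option` lets `take` move the state
/// out while the key stays in the LRU map.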
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
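
/// Result of a cache lookup for a join key.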
pub(crate) enum CacheResult<E: JoinEncoding> {
    /// The key is known to never match any row, e.g. a null join key in a
    /// non-null-safe join, so there is no state to fetch.
    NeverMatch,
    /// Cache miss: the entry must be fetched from the state table.
    Miss,
    /// Cache hit: the cached entry for the key.
    Hit(HashValueType<E>),
}

impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        self.0.estimated_heap_size()
    }
}

impl<E: JoinEncoding> HashValueWrapper<E> {
    const MESSAGE: &'static str = "the state should always be `Some`";

    /// Take the value out of the wrapper, leaving `None` in place.
    pub fn take(&mut self) -> HashValueType<E> {
        self.0.take().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
    type Target = HashValueType<E>;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut().expect(Self::MESSAGE)
    }
}

type JoinHashMapInner<K, E> =
    ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;
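
/// Counters buffered locally and flushed to the labeled metrics in `flush`.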
pub struct JoinHashMapMetrics {
    lookup_miss_count: usize,
    total_lookup_count: usize,
    insert_cache_miss_count: usize,

    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metric: LabelGuardedIntCounter,
}

impl JoinHashMapMetrics {
    pub fn new(
        metrics: &StreamingMetrics,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
        join_table_id: u32,
    ) -> Self {
        let actor_id = actor_id.to_string();
        let fragment_id = fragment_id.to_string();
        let join_table_id = join_table_id.to_string();
        let join_lookup_total_count_metric = metrics
            .join_lookup_total_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_lookup_miss_count_metric = metrics
            .join_lookup_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_insert_cache_miss_count_metric = metrics
            .join_insert_cache_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);

        Self {
            lookup_miss_count: 0,
            total_lookup_count: 0,
            insert_cache_miss_count: 0,
            join_lookup_total_count_metric,
            join_lookup_miss_count_metric,
            join_insert_cache_miss_count_metric,
        }
    }

    pub fn flush(&mut self) {
        self.join_lookup_total_count_metric
            .inc_by(self.total_lookup_count as u64);
        self.join_lookup_miss_count_metric
            .inc_by(self.lookup_miss_count as u64);
        self.join_insert_cache_miss_count_metric
            .inc_by(self.insert_cache_miss_count as u64);
        self.total_lookup_count = 0;
        self.lookup_miss_count = 0;
        self.insert_cache_miss_count = 0;
    }
}
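
/// The column index and serializer of the inequality key, if the join has an
/// inequality condition.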
struct InequalityKeyDesc {
    idx: usize,
    serializer: OrderedRowSerde,
}

impl InequalityKeyDesc {
    /// Serialize the inequality key from `row`.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        row.project(&[self.idx])
            .memcmp_serialize(&self.serializer)
    }
}
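
/// In-memory cache and state-table accessors for one side of a hash join.
/// Maps a join key to all rows on this side that share the key.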
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// In-memory LRU cache of join entries.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns.
    join_key_data_types: Vec<DataType>,
    /// Null-safe bitmap for the join key columns.
    null_matched: K::Bitmap,
    /// Serializer for the upstream primary key.
    pk_serializer: OrderedRowSerde,
    /// The state table holding the rows themselves.
    state: TableInner<S>,
    /// The degree table, tracking how many rows on the other side each row
    /// has matched. `None` if degrees are not maintained.
    degree_state: Option<TableInner<S>>,
    /// Whether degrees are maintained, i.e. `degree_state.is_some()`.
    need_degree_table: bool,
    /// Whether the primary key is fully contained in the join key, in which
    /// case a join key maps to at most one row.
    pk_contained_in_jk: bool,
    /// Descriptor of the inequality key, if any.
    inequality_key_desc: Option<InequalityKeyDesc>,
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// Fetch the matched rows for `key` from the state table, along with the
    /// order key indices and a mutable reference to the degree table.
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}
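
/// Streams all rows under the given join key from the state table as
/// `(serialized pk, JoinRow)` pairs, attaching the pre-fetched degrees
/// positionally (or 0 if no degree table is used).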
#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
        yield (encoded_pk, join_row);
    }
}
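
/// Fetch the degree of each row under the given join key from the degree
/// table, in the table's iteration order.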
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
        .await?;
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        let degree_i64 = degree_row
            .datum_at(degree_row.len() - 1)
            .expect("degree should not be NULL");
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}
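
/// Adjust the degree of a matched row by one, in memory and in the degree
/// state table. `INCREMENT` selects between increment and decrement.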
pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
    order_key_indices: &[usize],
    degree_state: &mut TableInner<S>,
    matched_row: &mut JoinRow<impl Row>,
) {
    let old_degree_row = (&matched_row.row)
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    if INCREMENT {
        matched_row.degree += 1;
    } else {
        matched_row.degree -= 1;
    }
    let new_degree_row = (&matched_row.row)
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    degree_state.table.update(old_degree_row, new_degree_row);
}
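
/// A state table plus the indices needed to interpret its rows.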
pub struct TableInner<S: StateStore> {
    /// Indices of the primary key columns in the table row.
    pk_indices: Vec<usize>,
    /// Indices of the join key columns in the table row.
    join_key_indices: Vec<usize>,
    /// The order key indices of the state table, i.e. `table.pk_indices()`.
    order_key_indices: Vec<usize>,
    pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
        let order_key_indices = table.pk_indices().to_vec();
        Self {
            pk_indices,
            join_key_indices,
            order_key_indices,
            table,
        }
    }

    fn error_context(&self, row: &impl Row) -> String {
        let pk = row.project(&self.pk_indices);
        let jk = row.project(&self.join_key_indices);
        format!(
            "join key: {}, pk: {}, row: {}, state_table_id: {}",
            jk.display(),
            pk.display(),
            row.display(),
            self.table.table_id()
        )
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] backed by the given state table(s).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        let alloc = StatsAlloc::new(Global).shared();

        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        let cache = ManagedLruCache::unbounded_with_hasher_in(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
            alloc,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    /// Initialize the state tables with the first epoch.
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        if cache_may_stale.unwrap_or(false) {
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Update the state tables' watermark, so that stale rows can be cleaned.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }
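
    /// Take the state for the given `key` out of the cache without going to
    /// storage. Returns `Miss` when the key is not cached, so the caller can
    /// decide whether to fetch from the state table itself.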
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            // Use `peek_mut` to avoid reordering the LRU, since the caller takes
            // the state and is expected to put it back.
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            CacheResult::Hit(state.take())
        } else {
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }
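
    /// Take the state for the given `key`, fetching it from the state table on
    /// a cache miss. The caller should put the state back via `update_state`
    /// after use.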
    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
        self.metrics.total_lookup_count += 1;
        let state = if self.inner.contains(key) {
            // `peek_mut` avoids reordering the LRU; the caller puts the state back.
            let mut state = self.inner.peek_mut(key).unwrap();
            state.take()
        } else {
            self.metrics.lookup_miss_count += 1;
            self.fetch_cached_state(key).await?.into()
        };
        Ok(state)
    }
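
    /// Fetch one cache entry from the state table. Returns an empty
    /// `JoinEntryState` when nothing is stored under the key.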
    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
        let key = key.deserialize(&self.join_key_data_types)?;

        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();

        if self.need_degree_table {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );
            let degree_state = self.degree_state.as_ref().unwrap();
            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );

            let (table_iter, degree_table_iter) =
                try_join(table_iter_fut, degree_table_iter_fut).await?;

            let mut pinned_table_iter = std::pin::pin!(table_iter);
            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);

            // Collect the rows and degrees first, so that a consistency issue
            // (if any) is detected before the cache entry is built.
            let mut rows = vec![];
            let mut degree_rows = vec![];
            let mut inconsistency_happened = false;
            loop {
                let (row, degree_row) =
                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
                let (row, degree_row) = match (row, degree_row) {
                    (None, None) => break,
                    (None, Some(_)) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(_), None) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, input table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(r), Some(d)) => (r, d),
                };

                let row = row?;
                let degree_row = degree_row?;
                rows.push(row);
                degree_rows.push(degree_row);
            }

            if inconsistency_happened {
                // The two tables are out of sync; match rows and degrees by pk.
                assert_ne!(rows.len(), degree_rows.len());

                let row_iter = stream::iter(rows.into_iter()).peekable();
                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
                pin_mut!(row_iter);
                pin_mut!(degree_row_iter);

                loop {
                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
                        (None, _) | (_, None) => break,
                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
                            Ordering::Greater => {
                                // Degree row without a matching row; skip it.
                                degree_row_iter.next().await;
                            }
                            Ordering::Less => {
                                // Row without a matching degree row; skip it.
                                row_iter.next().await;
                            }
                            Ordering::Equal => {
                                let row =
                                    row_iter.next().await.expect("we matched some(row) above");
                                let degree_row = degree_row_iter
                                    .next()
                                    .await
                                    .expect("we matched some(degree_row) above");
                                let pk = row
                                    .as_ref()
                                    .project(&self.state.pk_indices)
                                    .memcmp_serialize(&self.pk_serializer);
                                let degree_i64 = degree_row
                                    .datum_at(degree_row.len() - 1)
                                    .expect("degree should not be NULL");
                                let inequality_key = self
                                    .inequality_key_desc
                                    .as_ref()
                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                                entry_state
                                    .insert(
                                        pk,
                                        E::encode(&JoinRow::new(
                                            row.row(),
                                            degree_i64.into_int64() as u64,
                                        )),
                                        inequality_key,
                                    )
                                    .with_context(|| self.state.error_context(row.row()))?;
                            }
                        },
                    }
                }
            } else {
                // The two iterators are in sync; zip them directly.
                assert_eq!(rows.len(), degree_rows.len());

                #[for_await]
                for (row, degree_row) in
                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
                {
                    let row: KeyedRow<_> = row;
                    let degree_row: KeyedRow<_> = degree_row;

                    let pk1 = row.key();
                    let pk2 = degree_row.key();
                    debug_assert_eq!(
                        pk1, pk2,
                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
                    );
                    let pk = row
                        .as_ref()
                        .project(&self.state.pk_indices)
                        .memcmp_serialize(&self.pk_serializer);
                    let inequality_key = self
                        .inequality_key_desc
                        .as_ref()
                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                    let degree_i64 = degree_row
                        .datum_at(degree_row.len() - 1)
                        .expect("degree should not be NULL");
                    entry_state
                        .insert(
                            pk,
                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
                            inequality_key,
                        )
                        .with_context(|| self.state.error_context(row.row()))?;
                }
            }
        } else {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter = self
                .state
                .table
                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
                .await?;

            #[for_await]
            for entry in table_iter {
                let row: KeyedRow<_> = entry?;
                let pk = row
                    .as_ref()
                    .project(&self.state.pk_indices)
                    .memcmp_serialize(&self.pk_serializer);
                let inequality_key = self
                    .inequality_key_desc
                    .as_ref()
                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                entry_state
                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
                    .with_context(|| self.state.error_context(row.row()))?;
            }
        };

        Ok(entry_state)
    }

    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
        self.metrics.flush();
        let state_post_commit = self.state.table.commit(epoch).await?;
        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
            Some(degree_state.table.commit(epoch).await?)
        } else {
            None
        };
        Ok(JoinHashMapPostCommit {
            state: state_post_commit,
            degree_state: degree_state_post_commit,
            inner: &mut self.inner,
        })
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state.table.try_flush().await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.try_flush().await?;
        }
        Ok(())
    }

    /// Insert a join row, maintaining the degree table if one is used.
    pub fn insert_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.insert(key, value)
        } else {
            self.insert_row(key, value.row)
        }
    }
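
    /// Insert a join row, updating both the cache (if the key is cached) and
    /// the state table, along with the degree table when one is maintained.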
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        // Update the cache entry if it exists.
        if self.inner.contains(key) {
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            // When the primary key is contained in the join key, a join key maps
            // to at most one row, so a fresh entry is guaranteed to be complete
            // and can be cached without fetching from storage.
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        // Update the state table (and the degree table, if any).
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }

    /// Insert a row with a degree of 0, for joins that do not track degrees.
    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        let join_row = JoinRow::new(&value, 0);
        self.insert(key, join_row)?;
        Ok(())
    }

    /// Delete a join row, maintaining the degree table if one is used.
    pub fn delete_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.delete(key, value)
        } else {
            self.delete_row(key, value.row)
        }
    }
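
    /// Delete a join row from the cache, the state table, and the degree
    /// table. Must only be called when a degree table is maintained.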
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value.row)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);
            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value.row))?;
        }

        // Update the state table and the degree table regardless of the cache.
        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
        self.state.table.delete(row);
        let degree_state = self.degree_state.as_mut().expect("degree table missing");
        degree_state.table.delete(degree);
        Ok(())
    }

    /// Delete a row from a join that does not track degrees.
    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);

            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value))?;
        }

        self.state.table.delete(value);
        Ok(())
    }

    /// Put the taken state back into the cache.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict the LRU cache according to the watermark.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Number of cached join key entries.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    pub fn table_id(&self) -> u32 {
        self.state.table.table_id()
    }

    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Return `true` if the inequality key of `row` is null.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self
            .inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing");
        row.datum_at(desc.idx).is_none()
    }

    /// Serialize the inequality key from `row`.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing")
            .serialize_inequal_key_from_row(&row)
    }

    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
}
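
/// Handles returned by [`JoinHashMap::flush`]. `post_yield_barrier` must be
/// called after the barrier is yielded, to finish the state table commit.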
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    inner: &'a mut JoinHashMapInner<K, E>,
}
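
/// All cached rows of one join key, with a secondary index keyed by the
/// inequality key (when the join has an inequality condition).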
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The cached rows, keyed by the serialized primary key.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Secondary index from inequality key to the set of primary keys sharing it.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    kv_heap_size: KvSize,
}

impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        self.kv_heap_size.size()
    }
}

#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}

impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert a row into the cache, updating the inequality index if needed.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // In non-strict mode, remove any existing entry first so that the
            // insertion below cannot fail.
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // We removed an existing entry above; report the double insert.
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }

    /// Remove a row from the cache, updating the inequality index if needed.
    pub fn remove(
        &mut self,
        pk: PkType,
        inequality_key: Option<&InequalKeyType>,
    ) -> Result<(), JoinEntryError> {
        if let Some(value) = self.cached.remove(&pk) {
            self.kv_heap_size.sub(&pk, &value);
            if let Some(inequality_key) = inequality_key {
                self.remove_pk_from_inequality_index(&pk, inequality_key);
            }
            Ok(())
        } else if enable_strict_consistency() {
            Err(JoinEntryError::Remove)
        } else {
            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
            Ok(())
        }
    }

    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
            if pk_set.remove(pk).is_none() {
                if enable_strict_consistency() {
                    panic!("removing a pk that is not in the inequality index");
                } else {
                    consistency_error!(?pk, "removing a pk that is not in the inequality index");
                };
            } else {
                self.kv_heap_size.sub(pk, &());
            }
            if pk_set.is_empty() {
                self.inequality_index.remove(inequality_key);
            }
        }
    }

    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
            let pk_size = pk.estimated_size();
            if pk_set.try_insert(pk, ()).is_err() {
                if enable_strict_consistency() {
                    panic!("inserting a pk that is already in the inequality index");
                } else {
                    consistency_error!("inserting a pk that is already in the inequality index");
                };
            } else {
                self.kv_heap_size.add_size(pk_size);
            }
        } else {
            let mut pk_set = JoinRowSet::default();
            pk_set.try_insert(pk, ()).expect("pk set should be empty");
            self.inequality_index
                .try_insert(inequality_key, pk_set)
                .expect("pk set should be empty");
        }
    }

    pub fn get(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        self.cached
            .get(pk)
            .map(|encoded| encoded.decode(data_types))
    }

    /// Return a mutable iterator over the encoded rows, paired with their
    /// decoded form.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<E::DecodedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }

    pub fn len(&self) -> usize {
        self.cached.len()
    }

    /// Iterate over all rows whose inequality key falls in `range`.
    pub fn range_by_inequality<'a, R>(
        &'a self,
        range: R,
        data_types: &'a [DataType],
    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<E::DecodedRow>>> + 'a
    where
        R: RangeBounds<InequalKeyType> + 'a,
    {
        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
            pk_set
                .keys()
                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
        })
    }

    pub fn upper_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn get_by_indexed_pk(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some(value) = self.cached.get(pk) {
            Some(value.decode(data_types))
        } else if enable_strict_consistency() {
            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
        } else {
            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
            None
        }
    }

    pub fn lower_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn get_first_by_inequality<'a>(
        &'a self,
        inequality_key: &InequalKeyType,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // `values_mut` iterates in the order of the serialized pk.
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
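
    // A minimal sketch (not part of the original suite) exercising `remove`
    // and the inequality-index bookkeeping; it reuses the helpers above, and
    // the key serialization below mirrors what `insert_chunk` does.
    #[tokio::test]
    async fn test_managed_join_state_remove() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let pk_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        // Two rows sharing the same inequality key 4.
        let data_chunk = DataChunk::from_pretty(
            "I I
             1 4
             2 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk,
        );
        assert_eq!(managed_state.len(), 2);

        // Remove pk = 1. The index entry for key 4 should still point at pk = 2.
        let pk = OwnedRow::new(vec![Some(ScalarImpl::Int64(1))]).memcmp_serialize(&pk_serializer);
        let ineq = OwnedRow::new(vec![Some(ScalarImpl::Int64(4))])
            .memcmp_serialize(&inequality_key_serializer);
        managed_state.remove(pk, Some(&ineq)).unwrap();
        assert_eq!(managed_state.len(), 1);

        let row = managed_state
            .get_first_by_inequality(&ineq, &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(2)));
    }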
}