use std::alloc::Global;
use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use futures::future::{join, try_join};
use futures::{StreamExt, pin_mut, stream};
use futures_async_stream::for_await;
use join_row_set::JoinRowSet;
use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt, once};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::{EstimateSize, KvSize};
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;
use thiserror::Error;
use thiserror_ext::AsReport;

use super::*;
use super::row::{CachedJoinRow, DegreeType};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::prelude::{Stream, try_stream};
use crate::executor::{JoinEncoding, StreamExecutorError};
use crate::task::{ActorId, AtomicU64Ref, FragmentId};

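/// Row-key type: the row's primary key, memcmp-serialized, used to address a cached row inside a
/// [`JoinEntryState`].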
type PkType = Vec<u8>;
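/// Memcmp-serialized inequality-key column, used to index the rows of an entry for inequality
/// joins.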
type InequalKeyType = Vec<u8>;

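/// The value stored per join key in the LRU cache: a boxed set of matched rows.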
pub type HashValueType<E> = Box<JoinEntryState<E>>;

impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
    fn estimated_heap_size(&self) -> usize {
        self.as_ref().estimated_heap_size()
    }
}

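/// Wrapper around the cached value that lets the owner `take` the boxed state out of the cache
/// slot (leaving `None` behind) and put it back after processing.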
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);

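/// Outcome of a cache lookup for a join key.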
pub(crate) enum CacheResult<E: JoinEncoding> {
    /// The key is guaranteed to never match any row; no lookup is needed.
    NeverMatch,
    /// The key is not in the cache; the caller must fetch the entry from the state table.
    Miss,
    /// The key is in the cache; carries the cached entry.
    Hit(HashValueType<E>),
}

impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        self.0.estimated_heap_size()
    }
}

impl<E: JoinEncoding> HashValueWrapper<E> {
    const MESSAGE: &'static str = "the state should always be `Some`";

    /// Take the value out of the wrapper, panicking if it has already been taken.
    pub fn take(&mut self) -> HashValueType<E> {
        self.0.take().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
    type Target = HashValueType<E>;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut().expect(Self::MESSAGE)
    }
}

type JoinHashMapInner<K, E> =
    ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;

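/// Locally buffered lookup/insert statistics for one side of a hash join, flushed to the guarded
/// metric counters when [`JoinHashMapMetrics::flush`] is called.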
pub struct JoinHashMapMetrics {
    /// Lookups that missed the cache since the last flush.
    lookup_miss_count: usize,
    total_lookup_count: usize,
    /// Inserts that missed the cache since the last flush.
    insert_cache_miss_count: usize,

    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}

impl JoinHashMapMetrics {
    pub fn new(
        metrics: &StreamingMetrics,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
        join_table_id: u32,
    ) -> Self {
        let actor_id = actor_id.to_string();
        let fragment_id = fragment_id.to_string();
        let join_table_id = join_table_id.to_string();
        let join_lookup_total_count_metric = metrics
            .join_lookup_total_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_lookup_miss_count_metric = metrics
            .join_lookup_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_insert_cache_miss_count_metrics = metrics
            .join_insert_cache_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);

        Self {
            lookup_miss_count: 0,
            total_lookup_count: 0,
            insert_cache_miss_count: 0,
            join_lookup_total_count_metric,
            join_lookup_miss_count_metric,
            join_insert_cache_miss_count_metrics,
        }
    }

    pub fn flush(&mut self) {
        self.join_lookup_total_count_metric
            .inc_by(self.total_lookup_count as u64);
        self.join_lookup_miss_count_metric
            .inc_by(self.lookup_miss_count as u64);
        self.join_insert_cache_miss_count_metrics
            .inc_by(self.insert_cache_miss_count as u64);
        self.total_lookup_count = 0;
        self.lookup_miss_count = 0;
        self.insert_cache_miss_count = 0;
    }
}

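/// Index and serializer for the inequality-key column of an inequality join.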
struct InequalityKeyDesc {
    idx: usize,
    serializer: OrderedRowSerde,
}

impl InequalityKeyDesc {
    /// Serialize the inequality key from a row.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        let indices = vec![self.idx];
        let inequality_key = row.project(&indices);
        inequality_key.memcmp_serialize(&self.serializer)
    }
}

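/// In-memory cache over the per-key join state, backed by a state table (and an optional degree
/// table) in the state store.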
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Cache from join key to the rows cached for that key.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns.
    join_key_data_types: Vec<DataType>,
    /// Null-safe bitmap for the join key.
    null_matched: K::Bitmap,
    /// Serializer for the upstream primary key.
    pk_serializer: OrderedRowSerde,
    /// The state table holding the join rows.
    state: TableInner<S>,
    /// The degree table tracking per-row match counts, if degrees are maintained.
    degree_state: Option<TableInner<S>>,
    /// Whether degrees need to be maintained for this side.
    need_degree_table: bool,
    /// Whether the primary key is contained in the join key.
    pk_contained_in_jk: bool,
    /// Descriptor of the inequality key, if one is used.
    inequality_key_desc: Option<InequalityKeyDesc>,
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// Fetch the matched rows for the given key, together with the order key indices and a
    /// mutable reference to the degree table, so the caller can update degrees while consuming
    /// the stream.
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}

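/// Stream all rows under the given join key from the state table, yielding each row's
/// memcmp-serialized pk together with a [`JoinRow`] carrying its degree (0 if no degrees were
/// provided).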
#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        let join_row = JoinRow::new(
            encoded_row.into_owned_row(),
            degrees.as_ref().map_or(0, |d| d[i]),
        );
        yield (encoded_pk, join_row);
    }
}

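/// Fetch the degrees for all rows under the given join key from the degree table, in the same
/// iteration order as the corresponding rows in the state table, so the caller can zip them by
/// index.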
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
        .await?;
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        let degree_i64 = degree_row
            .datum_at(degree_row.len() - 1)
            .expect("degree should not be NULL");
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}

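/// Update the degree of a matched row, incrementing it if `INCREMENT` is true and decrementing
/// it otherwise, and write the change through to the degree table.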
pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
    order_key_indices: &[usize],
    degree_state: &mut TableInner<S>,
    matched_row: &mut JoinRow<OwnedRow>,
) {
    let old_degree_row = matched_row
        .row
        .as_ref()
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    if INCREMENT {
        matched_row.degree += 1;
    } else {
        matched_row.degree -= 1;
    }
    let new_degree_row = matched_row
        .row
        .as_ref()
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    degree_state.table.update(old_degree_row, new_degree_row);
}

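/// A state table bundled with the index metadata needed to interpret its rows on one join side.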
pub struct TableInner<S: StateStore> {
    /// Indices of the primary key columns in the row.
    pk_indices: Vec<usize>,
    /// Indices of the join key columns in the row.
    join_key_indices: Vec<usize>,
    /// The order key indices of the state table.
    order_key_indices: Vec<usize>,
    pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
        let order_key_indices = table.pk_indices().to_vec();
        Self {
            pk_indices,
            join_key_indices,
            order_key_indices,
            table,
        }
    }

    fn error_context(&self, row: &impl Row) -> String {
        let pk = row.project(&self.pk_indices);
        let jk = row.project(&self.join_key_indices);
        format!(
            "join key: {}, pk: {}, row: {}, state_table_id: {}",
            jk.display(),
            pk.display(),
            row.display(),
            self.table.table_id()
        )
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] with the given configuration.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        let alloc = StatsAlloc::new(Global).shared();

        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        let cache = ManagedLruCache::unbounded_with_hasher_in(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
            alloc,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    /// Initialize the state tables with the given epoch.
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    /// Finish the commit after the barrier has been yielded: apply any vnode bitmap update to
    /// the underlying tables and clear the cache if it may have become stale.
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        if cache_may_stale.unwrap_or(false) {
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Update the watermark on the state and degree tables.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }

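    /// Take the state for the given `key` out of the in-memory cache, if it is resident. This
    /// never touches the state store: a cold key yields [`CacheResult::Miss`] and the caller
    /// decides whether to fall back to [`Self::take_state`]. A typical call site looks roughly
    /// like this (an illustrative sketch, not a verbatim caller):
    ///
    /// ```ignore
    /// match join_hash_map.take_state_opt(&key) {
    ///     CacheResult::Hit(entry) => { /* use the cached rows, then `update_state` */ }
    ///     CacheResult::Miss => { /* fetch via `take_state(&key).await` */ }
    ///     CacheResult::NeverMatch => { /* the key can never match; skip */ }
    /// }
    /// ```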
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            // Use `peek_mut` to avoid promoting the entry in the LRU order, since the state is
            // taken out and will be put back via `update_state`.
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            CacheResult::Hit(state.take())
        } else {
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }

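    /// Take the state for the given `key` out of the hash table, fetching it from the state
    /// store on a cache miss. The caller is expected to put the state back with
    /// [`Self::update_state`] after use.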
    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
        self.metrics.total_lookup_count += 1;
        let state = if self.inner.contains(key) {
            // Use `peek_mut` to avoid promoting the entry in the LRU order, since the state is
            // taken out and will be put back via `update_state`.
            let mut state = self.inner.peek_mut(key).unwrap();
            state.take()
        } else {
            self.metrics.lookup_miss_count += 1;
            self.fetch_cached_state(key).await?.into()
        };
        Ok(state)
    }

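    /// Fetch the rows (and degrees, if maintained) for the given `key` from the state store and
    /// assemble them into an in-memory [`JoinEntryState`].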
    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
        let key = key.deserialize(&self.join_key_data_types)?;

        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();

        if self.need_degree_table {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );
            let degree_state = self.degree_state.as_ref().unwrap();
            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );

            let (table_iter, degree_table_iter) =
                try_join(table_iter_fut, degree_table_iter_fut).await?;

            let mut pinned_table_iter = std::pin::pin!(table_iter);
            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);

            // Iterate the rows and the degree rows in lockstep, so that a mismatch between the
            // two tables can be detected.
            let mut rows = vec![];
            let mut degree_rows = vec![];
            let mut inconsistency_happened = false;
            loop {
                let (row, degree_row) =
                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
                let (row, degree_row) = match (row, degree_row) {
                    (None, None) => break,
                    (None, Some(_)) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(_), None) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, input table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(r), Some(d)) => (r, d),
                };

                let row = row?;
                let degree_row = degree_row?;
                rows.push(row);
                degree_rows.push(degree_row);
            }

            if inconsistency_happened {
                // Merge the two pk-sorted sequences, keeping only rows that have a matching
                // degree row.
                assert_ne!(rows.len(), degree_rows.len());

                let row_iter = stream::iter(rows.into_iter()).peekable();
                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
                pin_mut!(row_iter);
                pin_mut!(degree_row_iter);

                loop {
                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
                        (None, _) | (_, None) => break,
                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
                            Ordering::Greater => {
                                degree_row_iter.next().await;
                            }
                            Ordering::Less => {
                                row_iter.next().await;
                            }
                            Ordering::Equal => {
                                let row =
                                    row_iter.next().await.expect("we matched some(row) above");
                                let degree_row = degree_row_iter
                                    .next()
                                    .await
                                    .expect("we matched some(degree_row) above");
                                let pk = row
                                    .as_ref()
                                    .project(&self.state.pk_indices)
                                    .memcmp_serialize(&self.pk_serializer);
                                let degree_i64 = degree_row
                                    .datum_at(degree_row.len() - 1)
                                    .expect("degree should not be NULL");
                                let inequality_key = self
                                    .inequality_key_desc
                                    .as_ref()
                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                                entry_state
                                    .insert(
                                        pk,
                                        E::encode(&JoinRow::new(
                                            row.row(),
                                            degree_i64.into_int64() as u64,
                                        )),
                                        inequality_key,
                                    )
                                    .with_context(|| self.state.error_context(row.row()))?;
                            }
                        },
                    }
                }
            } else {
                // No inconsistency was detected, so the rows in the table and the degree table
                // are assumed to be aligned one to one.
                assert_eq!(rows.len(), degree_rows.len());

                #[for_await]
                for (row, degree_row) in
                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
                {
                    let pk1 = row.key();
                    let pk2 = degree_row.key();
                    debug_assert_eq!(
                        pk1, pk2,
                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
                    );
                    let pk = row
                        .as_ref()
                        .project(&self.state.pk_indices)
                        .memcmp_serialize(&self.pk_serializer);
                    let inequality_key = self
                        .inequality_key_desc
                        .as_ref()
                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                    let degree_i64 = degree_row
                        .datum_at(degree_row.len() - 1)
                        .expect("degree should not be NULL");
                    entry_state
                        .insert(
                            pk,
                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
                            inequality_key,
                        )
                        .with_context(|| self.state.error_context(row.row()))?;
                }
            }
        } else {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter = self
                .state
                .table
                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
                .await?;

            #[for_await]
            for entry in table_iter {
                let row = entry?;
                let pk = row
                    .as_ref()
                    .project(&self.state.pk_indices)
                    .memcmp_serialize(&self.pk_serializer);
                let inequality_key = self
                    .inequality_key_desc
                    .as_ref()
                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                entry_state
                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
                    .with_context(|| self.state.error_context(row.row()))?;
            }
        };

        Ok(entry_state)
    }

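    /// Flush the buffered metrics and commit both state tables for the given epoch, returning a
    /// post-commit handle that must be finished after the barrier is yielded.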
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
        self.metrics.flush();
        let state_post_commit = self.state.table.commit(epoch).await?;
        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
            Some(degree_state.table.commit(epoch).await?)
        } else {
            None
        };
        Ok(JoinHashMapPostCommit {
            state: state_post_commit,
            degree_state: degree_state_post_commit,
            inner: &mut self.inner,
        })
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state.table.try_flush().await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.try_flush().await?;
        }
        Ok(())
    }

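    /// Insert a join row, maintaining the degree table only if this side needs degrees.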
    pub fn insert_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.insert(key, value)
        } else {
            self.insert_row(key, value.row)
        }
    }

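    /// Insert a join row into the cache (if the entry is resident, or if the pk is contained in
    /// the join key, in which case a fresh entry is created) as well as into the state table and
    /// the degree table.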
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        if self.inner.contains(key) {
            // Update the cached entry.
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            // Refill the cache when the join key contains the primary key.
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        // Update the flush buffer of the state (and degree) table.
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }

    /// Insert a row with a degree of 0.
    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        let join_row = JoinRow::new(&value, 0);
        self.insert(key, join_row)?;
        Ok(())
    }

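    /// Delete a join row, maintaining the degree table only if this side needs degrees.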
    pub fn delete_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.delete(key, value)
        } else {
            self.delete_row(key, value.row)
        }
    }

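    /// Delete a join row and its degree row from the cache (if resident) and the state tables.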
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value.row)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);
            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value.row))?;
        }

        // If the entry is not cached, only the state tables are updated.
        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
        self.state.table.delete(row);
        let degree_state = self.degree_state.as_mut().expect("degree table missing");
        degree_state.table.delete(degree);
        Ok(())
    }

    /// Delete a row that carries no degree.
    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);

            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value))?;
        }

        self.state.table.delete(value);
        Ok(())
    }

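    /// Put the (possibly modified) entry state back into the cache for the given key.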
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict the cache.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Number of cached join-key entries.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    pub fn table_id(&self) -> u32 {
        self.state.table.table_id()
    }

    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Return whether the inequality key column of the given row is NULL.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self
            .inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing");
        row.datum_at(desc.idx).is_none()
    }

    /// Serialize the inequality key from a row.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing")
            .serialize_inequal_key_from_row(&row)
    }

    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
}

#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    inner: &'a mut JoinHashMapInner<K, E>,
}

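/// The in-memory representation of all rows cached for one join key: an ordered map from the
/// serialized pk to the encoded row, plus an optional index on the inequality key.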
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// Rows of this entry, keyed by their serialized pk.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Index of the rows by their inequality key, mapping to the set of pks sharing that key.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    kv_heap_size: KvSize,
}

impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        self.kv_heap_size.size()
    }
}

#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}

impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert a row into the cache.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // Strict consistency is off, so remove the existing entry (if any) first.
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }

    /// Remove a row from the cache.
    pub fn remove(
        &mut self,
        pk: PkType,
        inequality_key: Option<&InequalKeyType>,
    ) -> Result<(), JoinEntryError> {
        if let Some(value) = self.cached.remove(&pk) {
            self.kv_heap_size.sub(&pk, &value);
            if let Some(inequality_key) = inequality_key {
                self.remove_pk_from_inequality_index(&pk, inequality_key);
            }
            Ok(())
        } else if enable_strict_consistency() {
            Err(JoinEntryError::Remove)
        } else {
            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
            Ok(())
        }
    }

    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
            if pk_set.remove(pk).is_none() {
                if enable_strict_consistency() {
                    panic!("removing a pk that is not in the inequality index");
                } else {
                    consistency_error!(?pk, "removing a pk that is not in the inequality index");
                }
            } else {
                self.kv_heap_size.sub(pk, &());
            }
            if pk_set.is_empty() {
                self.inequality_index.remove(inequality_key);
            }
        }
    }

    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
            let pk_size = pk.estimated_size();
            if pk_set.try_insert(pk, ()).is_err() {
                if enable_strict_consistency() {
                    panic!("inserting a pk that is already in the inequality index");
                } else {
                    consistency_error!("inserting a pk that is already in the inequality index");
                }
            } else {
                self.kv_heap_size.add_size(pk_size);
            }
        } else {
            let mut pk_set = JoinRowSet::default();
            pk_set.try_insert(pk, ()).expect("pk set should be empty");
            self.inequality_index
                .try_insert(inequality_key, pk_set)
                .expect("inequality key should not yet exist in the index");
        }
    }

    pub fn get(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        self.cached
            .get(pk)
            .map(|encoded| encoded.decode(data_types))
    }

    /// Note: the first item in each yielded tuple is the mutable reference to the encoded value
    /// in this entry, while the second item is the decoded value. To mutate the degree, one must
    /// update not only the decoded value but also the encoded value behind the first item.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<OwnedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }

    pub fn len(&self) -> usize {
        self.cached.len()
    }

    /// Range-scan the cached rows using the inequality index.
    pub fn range_by_inequality<'a, R>(
        &'a self,
        range: R,
        data_types: &'a [DataType],
    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<OwnedRow>>> + 'a
    where
        R: RangeBounds<InequalKeyType> + 'a,
    {
        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
            pk_set
                .keys()
                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
        })
    }

    /// Look up the inequality-index entry at the upper bound for `bound` and return the row with
    /// the smallest pk in its pk set.
    pub fn upper_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn get_by_indexed_pk(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some(value) = self.cached.get(pk) {
            Some(value.decode(data_types))
        } else if enable_strict_consistency() {
            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
        } else {
            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
            None
        }
    }

    /// Look up the inequality-index entry at the lower bound for `bound` and return the row with
    /// the smallest pk in its pk set.
    pub fn lower_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    /// Return the row with the smallest pk among the rows whose inequality key equals
    /// `inequality_key`, if any.
    pub fn get_first_by_inequality<'a>(
        &'a self,
        inequality_key: &InequalKeyType,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            // Build the pk from the pk columns and memcmp-serialize it.
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            let pk = OwnedRow::new(pk).memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row[0], Some(ScalarImpl::Int64(*d1)));
            assert_eq!(matched_row.row[1], Some(ScalarImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        // Insert the first chunk and verify the cached rows.
        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // Insert two more rows and verify again.
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // Both pk 1 and pk 2 have inequality key 5; the smallest pk (1) is returned for the
        // inclusive bound, while the exclusive bound falls back to key 4 (pk 3).
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // There is no inequality key strictly greater than 8.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}