use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use futures::future::{join, try_join};
use futures::{StreamExt, pin_mut, stream};
use futures_async_stream::for_await;
use join_row_set::JoinRowSet;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt, once};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::table::KeyedRow;
use thiserror_ext::AsReport;

use super::row::{CachedJoinRow, DegreeType};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::{JoinEncoding, StreamExecutorError};
use crate::task::{ActorId, AtomicU64Ref, FragmentId};
/// Memcmp-encoded primary key of a row.
type PkType = Vec<u8>;
/// Memcmp-encoded inequality key of a row.
type InequalKeyType = Vec<u8>;

pub type HashValueType<E> = Box<JoinEntryState<E>>;
impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
    fn estimated_heap_size(&self) -> usize {
        self.as_ref().estimated_heap_size()
    }
}

/// Wrapper around a cached join entry that allows the entry to be moved out of
/// the LRU cache (leaving `None` behind) and put back later.
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);

pub(crate) enum CacheResult<E: JoinEncoding> {
    /// The key is known in advance to never produce a match.
    NeverMatch,
    /// Not in the cache; the caller may load the entry from storage.
    Miss,
    /// In the cache.
    Hit(HashValueType<E>),
}

impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        self.0.estimated_heap_size()
    }
}

impl<E: JoinEncoding> HashValueWrapper<E> {
    const MESSAGE: &'static str = "the state should always be `Some`";

    pub fn take(&mut self) -> HashValueType<E> {
        self.0.take().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
    type Target = HashValueType<E>;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut().expect(Self::MESSAGE)
    }
}
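
// Note on the `Option` invariant (illustrative): the wrapper stays `Some`
// except in the short window between `HashValueWrapper::take` and the
// `JoinHashMap::update_state` call that puts the entry back, which is why
// `Deref`/`DerefMut` may `expect` the value. A minimal sketch of the intended
// round trip, with `map` being a hypothetical `JoinHashMap`:
//
//     let entry = map.take_state(&key).await?; // moves the box out of the wrapper
//     // ... probe and mutate `entry` ...
//     map.update_state(&key, entry);           // re-wraps and re-inserts into the LRU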

type JoinHashMapInner<K, E> = ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher>;

pub struct JoinHashMapMetrics {
    /// Lookup misses accumulated since the last flush.
    lookup_miss_count: usize,
    /// Total lookups accumulated since the last flush.
    total_lookup_count: usize,
    /// Cache misses on the insert path accumulated since the last flush.
    insert_cache_miss_count: usize,

    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}

impl JoinHashMapMetrics {
    pub fn new(
        metrics: &StreamingMetrics,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
        join_table_id: TableId,
    ) -> Self {
        let actor_id = actor_id.to_string();
        let fragment_id = fragment_id.to_string();
        let join_table_id = join_table_id.to_string();
        let join_lookup_total_count_metric = metrics
            .join_lookup_total_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_lookup_miss_count_metric = metrics
            .join_lookup_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);
        let join_insert_cache_miss_count_metrics = metrics
            .join_insert_cache_miss_count
            .with_guarded_label_values(&[side, &join_table_id, &actor_id, &fragment_id]);

        Self {
            lookup_miss_count: 0,
            total_lookup_count: 0,
            insert_cache_miss_count: 0,
            join_lookup_total_count_metric,
            join_lookup_miss_count_metric,
            join_insert_cache_miss_count_metrics,
        }
    }

    pub fn flush(&mut self) {
        self.join_lookup_total_count_metric
            .inc_by(self.total_lookup_count as u64);
        self.join_lookup_miss_count_metric
            .inc_by(self.lookup_miss_count as u64);
        self.join_insert_cache_miss_count_metrics
            .inc_by(self.insert_cache_miss_count as u64);
        self.total_lookup_count = 0;
        self.lookup_miss_count = 0;
        self.insert_cache_miss_count = 0;
    }
}
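
// The metrics above follow a common amortization pattern: hot-path code bumps
// plain `usize` counters, and `flush` (invoked from `JoinHashMap::flush` at
// each barrier) folds them into the guarded Prometheus counters, so the
// per-row cost is a plain field increment rather than an atomic RMW. Sketch:
//
//     metrics.total_lookup_count += 1; // per row, cheap
//     // ... at the barrier ...
//     metrics.flush();                 // three `inc_by` calls, then reset to 0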

/// Describes the single column that forms the inequality key and how to
/// memcmp-encode it.
struct InequalityKeyDesc {
    idx: usize,
    serializer: OrderedRowSerde,
}

impl InequalityKeyDesc {
    /// Serialize the inequality key from a row.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        let indices = vec![self.idx];
        let inequality_key = row.project(&indices);
        inequality_key.memcmp_serialize(&self.serializer)
    }
}
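
// Why memcmp encoding: `memcmp_serialize` produces bytes whose lexicographic
// order equals the order of the underlying datums, so tree structures keyed by
// `InequalKeyType` can answer bound queries with plain byte comparisons. A
// minimal sketch (assuming a single ascending `Int64` column):
//
//     let serde = OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);
//     let a = OwnedRow::new(vec![Some(ScalarImpl::Int64(1))]).memcmp_serialize(&serde);
//     let b = OwnedRow::new(vec![Some(ScalarImpl::Int64(2))]).memcmp_serialize(&serde);
//     assert!(a < b); // byte order matches value order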

pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// In-memory LRU cache of join entries, keyed by the join key.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns.
    join_key_data_types: Vec<DataType>,
    /// Bitmap marking the join key columns that use null-safe equality.
    null_matched: K::Bitmap,
    /// Serializer for the upstream primary key.
    pk_serializer: OrderedRowSerde,
    /// The state table holding the rows of this side.
    state: TableInner<S>,
    /// The degree table, if this side maintains degrees.
    degree_state: Option<TableInner<S>>,
    /// Whether degrees need to be maintained (i.e. `degree_state.is_some()`).
    need_degree_table: bool,
    /// Whether the upstream primary key is contained in the join key.
    pk_contained_in_jk: bool,
    /// Description of the inequality key, if one is used.
    inequality_key_desc: Option<InequalityKeyDesc>,
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// Fetch the matched rows for `key` from the state table as a stream,
    /// together with the order key indices and a mutable reference to the
    /// degree table.
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}

#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
        yield (encoded_pk, join_row);
    }
}
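
// Note (illustrative): `degrees` is paired with the row stream purely by
// position. Both `into_stream` and `fetch_degrees` scan the same join-key
// prefix in the same storage order, so `degrees[i]` belongs to the i-th row;
// the consistency checks in `fetch_cached_state` guard this same assumption.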

/// Fetch all degrees for the given join key from the degree table, in the same
/// order as the rows under that key.
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
        .await?;
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        let degree_i64 = degree_row
            .datum_at(degree_row.len() - 1)
            .expect("degree should not be NULL");
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}

/// Locally update the degree of a matched row, and write the old and new
/// degree rows to the degree table. `INCREMENT == true` increments the degree,
/// otherwise it decrements.
pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
    order_key_indices: &[usize],
    degree_state: &mut TableInner<S>,
    matched_row: &mut JoinRow<impl Row>,
) {
    let old_degree_row = (&matched_row.row)
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    if INCREMENT {
        matched_row.degree += 1;
    } else {
        // Decrementing a zero degree would underflow; callers are expected to
        // keep insertions and deletions balanced.
        matched_row.degree -= 1;
    }
    let new_degree_row = (&matched_row.row)
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    degree_state.table.update(old_degree_row, new_degree_row);
}
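
// Shape of the degree rows written above (illustrative values): a degree row
// is the matched row projected onto the order key, with the degree appended as
// a trailing `Int64` column. Incrementing therefore rewrites the whole row,
// e.g. for an order key `(jk, pk)`:
//
//     old: (jk, pk, 2)
//     new: (jk, pk, 3)
//
//     update_degree::<S, true>(order_key_indices, degree_state, &mut matched_row);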

pub struct TableInner<S: StateStore> {
    /// Indices of the primary key columns within a row.
    pk_indices: Vec<usize>,
    /// Indices of the join key columns within a row.
    join_key_indices: Vec<usize>,
    /// Indices forming the state table's order key (its storage pk).
    order_key_indices: Vec<usize>,
    pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
        let order_key_indices = table.pk_indices().to_vec();
        Self {
            pk_indices,
            join_key_indices,
            order_key_indices,
            table,
        }
    }

    /// Render the join key, pk, and full row of `row` for error messages.
    fn error_context(&self, row: &impl Row) -> String {
        let pk = row.project(&self.pk_indices);
        let jk = row.project(&self.join_key_indices);
        format!(
            "join key: {}, pk: {}, row: {}, state_table_id: {}",
            jk.display(),
            pk.display(),
            row.display(),
            self.table.table_id()
        )
    }
}
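
// The string produced by `error_context` renders the join key, the pk, the
// full row (formats per `Row::display`), and the state table id, and is
// attached as `anyhow` context to insert/remove failures so a consistency
// error can be traced back to a concrete row and table.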

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] backed by the given state table (and optional
    /// degree table), with an unbounded LRU cache.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        let cache = ManagedLruCache::unbounded_with_hasher(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    /// Initialize the state tables with the first epoch.
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        // If the vnode bitmap changed, cached entries may no longer belong to
        // this actor, so the cache must be cleared.
        if cache_may_stale.unwrap_or(false) {
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Update the watermark on both state tables for state cleaning.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }

    /// Take the state for the given `key` out of the LRU cache, without
    /// falling back to storage. Returns [`CacheResult::Miss`] if the key is
    /// not resident; the caller decides whether to load it via
    /// [`JoinHashMap::take_state`].
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            CacheResult::Hit(state.take())
        } else {
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }
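
    // Sketch of the intended caller protocol (illustrative; the executor code
    // that drives this lives elsewhere):
    //
    //     match map.take_state_opt(&key) {
    //         CacheResult::Hit(entry) => {
    //             // ... probe `entry`, then hand it back so the key stays cached
    //             map.update_state(&key, entry);
    //         }
    //         CacheResult::Miss => {
    //             // fall back to `take_state`, which loads from storage
    //         }
    //         CacheResult::NeverMatch => {
    //             // constructed by callers for keys that can never match;
    //             // `take_state_opt` itself only returns `Hit` or `Miss`
    //         }
    //     }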

    /// Take the state for the given `key`, loading it from storage on a cache
    /// miss. The caller is expected to put the state back via
    /// [`JoinHashMap::update_state`] after use.
    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
        self.metrics.total_lookup_count += 1;
        let state = if self.inner.contains(key) {
            // Use `peek_mut` so the LRU statistics are not updated here; the
            // state will be put back into the cache afterwards.
            let mut state = self.inner.peek_mut(key).unwrap();
            state.take()
        } else {
            self.metrics.lookup_miss_count += 1;
            self.fetch_cached_state(key).await?.into()
        };
        Ok(state)
    }

    /// Fetch the entry for `key` from the state store and build an in-memory
    /// [`JoinEntryState`] from it. Returns an empty entry state if the key
    /// does not exist in storage.
    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
        let key = key.deserialize(&self.join_key_data_types)?;

        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();

        if self.need_degree_table {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );
            let degree_state = self.degree_state.as_ref().unwrap();
            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );

            let (table_iter, degree_table_iter) =
                try_join(table_iter_fut, degree_table_iter_fut).await?;

            let mut pinned_table_iter = std::pin::pin!(table_iter);
            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);

            // Collect the join rows and degree rows in lockstep, detecting
            // length mismatches between the two tables along the way.
            let mut rows = vec![];
            let mut degree_rows = vec![];
            let mut inconsistency_happened = false;
            loop {
                let (row, degree_row) =
                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
                let (row, degree_row) = match (row, degree_row) {
                    (None, None) => break,
                    (None, Some(_)) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(_), None) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, input table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(r), Some(d)) => (r, d),
                };

                let row = row?;
                let degree_row = degree_row?;
                rows.push(row);
                degree_rows.push(degree_row);
            }

            if inconsistency_happened {
                // The two streams diverged, so merge them by key to tolerate
                // the inconsistency.
                assert_ne!(rows.len(), degree_rows.len());

                let row_iter = stream::iter(rows.into_iter()).peekable();
                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
                pin_mut!(row_iter);
                pin_mut!(degree_row_iter);

                loop {
                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
                        (None, _) | (_, None) => break,
                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
                            Ordering::Greater => {
                                degree_row_iter.next().await;
                            }
                            Ordering::Less => {
                                row_iter.next().await;
                            }
                            Ordering::Equal => {
                                let row =
                                    row_iter.next().await.expect("we matched some(row) above");
                                let degree_row = degree_row_iter
                                    .next()
                                    .await
                                    .expect("we matched some(degree_row) above");
                                let pk = row
                                    .as_ref()
                                    .project(&self.state.pk_indices)
                                    .memcmp_serialize(&self.pk_serializer);
                                let degree_i64 = degree_row
                                    .datum_at(degree_row.len() - 1)
                                    .expect("degree should not be NULL");
                                let inequality_key = self
                                    .inequality_key_desc
                                    .as_ref()
                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                                entry_state
                                    .insert(
                                        pk,
                                        E::encode(&JoinRow::new(
                                            row.row(),
                                            degree_i64.into_int64() as u64,
                                        )),
                                        inequality_key,
                                    )
                                    .with_context(|| self.state.error_context(row.row()))?;
                            }
                        },
                    }
                }
            } else {
                // The happy path: the two tables have the same number of rows
                // under this key, so zip them strictly.
                assert_eq!(rows.len(), degree_rows.len());

                #[for_await]
                for (row, degree_row) in
                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
                {
                    let row: KeyedRow<_> = row;
                    let degree_row: KeyedRow<_> = degree_row;

                    let pk1 = row.key();
                    let pk2 = degree_row.key();
                    debug_assert_eq!(
                        pk1, pk2,
                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
                    );
                    let pk = row
                        .as_ref()
                        .project(&self.state.pk_indices)
                        .memcmp_serialize(&self.pk_serializer);
                    let inequality_key = self
                        .inequality_key_desc
                        .as_ref()
                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                    let degree_i64 = degree_row
                        .datum_at(degree_row.len() - 1)
                        .expect("degree should not be NULL");
                    entry_state
                        .insert(
                            pk,
                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
                            inequality_key,
                        )
                        .with_context(|| self.state.error_context(row.row()))?;
                }
            }
        } else {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter = self
                .state
                .table
                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
                .await?;

            #[for_await]
            for entry in table_iter {
                let row: KeyedRow<_> = entry?;
                let pk = row
                    .as_ref()
                    .project(&self.state.pk_indices)
                    .memcmp_serialize(&self.pk_serializer);
                let inequality_key = self
                    .inequality_key_desc
                    .as_ref()
                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                entry_state
                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
                    .with_context(|| self.state.error_context(row.row()))?;
            }
        };

        Ok(entry_state)
    }

    /// Flush the buffered metrics and commit both state tables for `epoch`,
    /// returning a post-commit handle.
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
        self.metrics.flush();
        let state_post_commit = self.state.table.commit(epoch).await?;
        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
            Some(degree_state.table.commit(epoch).await?)
        } else {
            None
        };
        Ok(JoinHashMapPostCommit {
            state: state_post_commit,
            degree_state: degree_state_post_commit,
            inner: &mut self.inner,
        })
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state.table.try_flush().await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.try_flush().await?;
        }
        Ok(())
    }

    /// Insert a join row, maintaining the degree table if one exists.
    pub fn insert_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.insert(key, value)
        } else {
            self.insert_row(key, value.row)
        }
    }

    /// Insert a join row into the cache (if the entry is resident, or can be
    /// created without a storage read) and into the state tables.
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        if self.inner.contains(key) {
            // Update the cached entry in place.
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            // When the pk is contained in the join key, a newly inserted row
            // implies the join key did not exist before, so the cache entry
            // can be refilled without reading storage first.
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        // Update the flush buffer of the state table(s).
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }

    /// Insert a row with a degree of 0, for sides that do not maintain a
    /// degree table.
    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        let join_row = JoinRow::new(&value, 0);
        self.insert(key, join_row)?;
        Ok(())
    }

    /// Delete a join row, maintaining the degree table if one exists.
    pub fn delete_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.delete(key, value)
        } else {
            self.delete_row(key, value.row)
        }
    }

    /// Delete a join row from the cache (if resident) and from both state
    /// tables.
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value.row)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);
            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value.row))?;
        }

        // Update the flush buffer of both tables.
        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
        self.state.table.delete(row);
        let degree_state = self.degree_state.as_mut().expect("degree table missing");
        degree_state.table.delete(degree);
        Ok(())
    }

    /// Delete a row without touching the degree table, for sides that do not
    /// maintain one.
    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);

            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value))?;
        }

        self.state.table.delete(value);
        Ok(())
    }
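
    // Note (illustrative): the degree-maintaining pair `insert`/`delete` and
    // the row-only pair `insert_row`/`delete_row` must not be mixed for the
    // same side, otherwise the row and degree tables drift apart. The
    // `*_handle_degree` wrappers above encode exactly this dispatch.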

    /// Put the entry for `key` (back) into the LRU cache.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Run cache eviction, driven by the global memory watermark.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Number of join keys resident in the cache.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    pub fn table_id(&self) -> TableId {
        self.state.table.table_id()
    }

    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Return whether the inequality key of `row` is NULL.
    ///
    /// Panics if no inequality key is configured.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self
            .inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing");
        row.datum_at(desc.idx).is_none()
    }

    /// Serialize the inequality key from a row.
    ///
    /// Panics if no inequality key is configured.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing")
            .serialize_inequal_key_from_row(&row)
    }

    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
}

#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    inner: &'a mut JoinHashMapInner<K, E>,
}

use risingwave_common::catalog::TableId;
use risingwave_common_estimate_size::KvSize;
use thiserror::Error;

use super::*;
use crate::executor::prelude::{Stream, try_stream};

/// In-memory state of all rows under a single join key: the encoded rows keyed
/// by their memcmp-encoded pk, plus a secondary index keyed by the inequality
/// key for joins with an inequality condition.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The encoded join rows, keyed by pk.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Secondary index from inequality key to the set of pks under that key.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    /// Tracked heap size of the two maps above.
    kv_heap_size: KvSize,
}

impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        self.kv_heap_size.size()
    }
}

#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}

impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert a new row. Under strict consistency, inserting an existing pk
    /// returns [`JoinEntryError::Occupied`]; otherwise the old row is
    /// overwritten and a consistency error is recorded.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // Remove any existing row for the same pk to tolerate
            // double-insertion under non-strict consistency.
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // A removal happened above, meaning the same pk was inserted twice.
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }

    /// Remove a row by pk. Under strict consistency, removing a non-existent
    /// row returns [`JoinEntryError::Remove`]; otherwise it is only recorded
    /// as a consistency error.
    pub fn remove(
        &mut self,
        pk: PkType,
        inequality_key: Option<&InequalKeyType>,
    ) -> Result<(), JoinEntryError> {
        if let Some(value) = self.cached.remove(&pk) {
            self.kv_heap_size.sub(&pk, &value);
            if let Some(inequality_key) = inequality_key {
                self.remove_pk_from_inequality_index(&pk, inequality_key);
            }
            Ok(())
        } else if enable_strict_consistency() {
            Err(JoinEntryError::Remove)
        } else {
            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
            Ok(())
        }
    }

    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
            if pk_set.remove(pk).is_none() {
                if enable_strict_consistency() {
                    panic!("removing a pk that is not in the inequality index");
                } else {
                    consistency_error!(?pk, "removing a pk that is not in the inequality index");
                };
            } else {
                self.kv_heap_size.sub(pk, &());
            }
            if pk_set.is_empty() {
                self.inequality_index.remove(inequality_key);
            }
        }
    }

    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
            let pk_size = pk.estimated_size();
            if pk_set.try_insert(pk, ()).is_err() {
                if enable_strict_consistency() {
                    panic!("inserting a pk that is already in the inequality index");
                } else {
                    consistency_error!("inserting a pk that is already in the inequality index");
                };
            } else {
                self.kv_heap_size.add_size(pk_size);
            }
        } else {
            let mut pk_set = JoinRowSet::default();
            pk_set.try_insert(pk, ()).expect("new pk set should be empty");
            self.inequality_index
                .try_insert(inequality_key, pk_set)
                .expect("the index entry should not exist yet");
        }
    }
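
    // Shape of the index maintained by the two helpers above (illustrative):
    // for rows (pk=1, ik=5), (pk=2, ik=5), (pk=3, ik=4) it holds
    //
    //     ik=4 -> { pk=3 }
    //     ik=5 -> { pk=1, pk=2 }
    //
    // so the bound queries below walk inequality keys in memcmp order and
    // break ties by the smallest pk via `first_key_sorted`.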

    pub fn get(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        self.cached
            .get(pk)
            .map(|encoded| encoded.decode(data_types))
    }

    /// Iterate over the rows, yielding for each one a mutable reference to the
    /// encoded form together with its decoded row.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<E::DecodedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }

    pub fn len(&self) -> usize {
        self.cached.len()
    }

    /// Iterate over all rows whose inequality key falls in `range`.
    pub fn range_by_inequality<'a, R>(
        &'a self,
        range: R,
        data_types: &'a [DataType],
    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<E::DecodedRow>>> + 'a
    where
        R: RangeBounds<InequalKeyType> + 'a,
    {
        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
            pk_set
                .keys()
                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
        })
    }

    /// Return the row with the largest inequality key within `bound`,
    /// breaking ties by the smallest pk.
    pub fn upper_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn get_by_indexed_pk(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some(value) = self.cached.get(pk) {
            Some(value.decode(data_types))
        } else if enable_strict_consistency() {
            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
        } else {
            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
            None
        }
    }

    /// Return the row with the smallest inequality key within `bound`,
    /// breaking ties by the smallest pk.
    pub fn lower_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    /// Return the row with the smallest pk under exactly `inequality_key`.
    pub fn get_first_by_inequality<'a>(
        &'a self,
        inequality_key: &InequalKeyType,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            // The pk is a single i64 in these tests, so the exact encoding
            // does not matter much here.
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
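
    // An extra check (a sketch under the same assumptions as the tests above):
    // removing a row must also unlink its pk from the inequality index, so a
    // lookup on that inequality key falls through to the remaining pk.
    #[tokio::test]
    async fn test_remove_updates_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let pk_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        // Two rows sharing the inequality key 5.
        let data_chunk = DataChunk::from_pretty(
            "I I
             1 5
             2 5",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk,
        );
        assert_eq!(managed_state.len(), 2);

        // Remove pk = 1; the index entry for key 5 must now resolve to pk = 2.
        let pk = OwnedRow::new(vec![Some(ScalarImpl::Int64(1))]).memcmp_serialize(&pk_serializer);
        let inequality_key = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        managed_state.remove(pk, Some(&inequality_key)).unwrap();
        assert_eq!(managed_state.len(), 1);

        let row = managed_state
            .get_first_by_inequality(&inequality_key, &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(2)));
    }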
}