1use std::ops::{Bound, Deref, DerefMut};
16use std::sync::Arc;
17
18use anyhow::Context;
19use futures::StreamExt;
20use futures_async_stream::for_await;
21use join_row_set::JoinRowSet;
22use risingwave_common::bitmap::Bitmap;
23use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
24use risingwave_common::metrics::LabelGuardedIntCounter;
25use risingwave_common::row::{OwnedRow, Row, RowExt};
26use risingwave_common::types::{DataType, ScalarImpl};
27use risingwave_common::util::epoch::EpochPair;
28use risingwave_common::util::row_serde::OrderedRowSerde;
29use risingwave_common::util::sort_util::OrderType;
30use risingwave_common_estimate_size::EstimateSize;
31use risingwave_storage::StateStore;
32use risingwave_storage::store::PrefetchOptions;
33
34use super::row::{CachedJoinRow, DegreeType, build_degree_row};
35use crate::cache::ManagedLruCache;
36use crate::common::metrics::MetricsInfo;
37use crate::common::table::state_table::{StateTable, StateTablePostCommit};
38use crate::consistency::{consistency_error, enable_strict_consistency};
39use crate::executor::error::StreamExecutorResult;
40use crate::executor::join::row::JoinRow;
41use crate::executor::monitor::StreamingMetrics;
42use crate::executor::{JoinEncoding, StreamExecutorError};
43use crate::task::{ActorId, AtomicU64Ref, FragmentId};
44
/// Memcmp-serialized primary key, used as the key within a join entry.
type PkType = Vec<u8>;
/// A per-join-key entry, boxed so the cache value stays pointer-sized.
pub type HashValueType<E> = Box<JoinEntryState<E>>;
48
49impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
50 fn estimated_heap_size(&self) -> usize {
51 self.as_ref().estimated_heap_size()
52 }
53}
54
55struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
61
62pub(crate) enum CacheResult<E: JoinEncoding> {
63 NeverMatch, Miss, Hit(HashValueType<E>), }
67
68impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
69 fn estimated_heap_size(&self) -> usize {
70 self.0.estimated_heap_size()
71 }
72}
73
74impl<E: JoinEncoding> HashValueWrapper<E> {
75 const MESSAGE: &'static str = "the state should always be `Some`";
76
77 pub fn take(&mut self) -> HashValueType<E> {
79 self.0.take().expect(Self::MESSAGE)
80 }
81}
82
83impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
84 type Target = HashValueType<E>;
85
86 fn deref(&self) -> &Self::Target {
87 self.0.as_ref().expect(Self::MESSAGE)
88 }
89}
90
91impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
92 fn deref_mut(&mut self) -> &mut Self::Target {
93 self.0.as_mut().expect(Self::MESSAGE)
94 }
95}
96
97type JoinHashMapInner<K, E> = ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher>;
98
/// Buffered metrics for one side of a hash join.
///
/// Counters are accumulated locally on the hot path and folded into the
/// guarded counters only on `flush`.
pub struct JoinHashMapMetrics {
    // Lookups that missed the in-memory cache since the last flush.
    lookup_miss_count: usize,
    // All lookups since the last flush.
    total_lookup_count: usize,
    // Inserts that had to create a fresh cache entry since the last flush.
    insert_cache_miss_count: usize,

    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}
112
113impl JoinHashMapMetrics {
114 pub fn new(
115 metrics: &StreamingMetrics,
116 actor_id: ActorId,
117 fragment_id: FragmentId,
118 side: &'static str,
119 join_table_id: TableId,
120 ) -> Self {
121 let actor_id = actor_id.to_string();
122 let fragment_id = fragment_id.to_string();
123 let join_table_id = join_table_id.to_string();
124 let join_lookup_total_count_metric = metrics
125 .join_lookup_total_count
126 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
127 let join_lookup_miss_count_metric = metrics
128 .join_lookup_miss_count
129 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
130 let join_insert_cache_miss_count_metrics = metrics
131 .join_insert_cache_miss_count
132 .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
133
134 Self {
135 lookup_miss_count: 0,
136 total_lookup_count: 0,
137 insert_cache_miss_count: 0,
138 join_lookup_total_count_metric,
139 join_lookup_miss_count_metric,
140 join_insert_cache_miss_count_metrics,
141 }
142 }
143
144 pub fn inc_lookup(&mut self) {
145 self.total_lookup_count += 1;
146 }
147
148 pub fn inc_lookup_miss(&mut self) {
149 self.lookup_miss_count += 1;
150 }
151
152 pub fn inc_insert_cache_miss(&mut self) {
153 self.insert_cache_miss_count += 1;
154 }
155
156 pub fn flush(&mut self) {
157 self.join_lookup_total_count_metric
158 .inc_by(self.total_lookup_count as u64);
159 self.join_lookup_miss_count_metric
160 .inc_by(self.lookup_miss_count as u64);
161 self.join_insert_cache_miss_count_metrics
162 .inc_by(self.insert_cache_miss_count as u64);
163 self.total_lookup_count = 0;
164 self.lookup_miss_count = 0;
165 self.insert_cache_miss_count = 0;
166 }
167}
168
/// In-memory cache plus persistent state tables for one side of a hash join.
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    // LRU cache from join key to the boxed set of rows under that key.
    inner: JoinHashMapInner<K, E>,
    // Data types of the join key columns; used to deserialize `K` into a row.
    join_key_data_types: Vec<DataType>,
    // Bitmap over join key columns for null matching — presumably marks columns
    // with null-safe equality; confirm at call sites of `null_matched()`.
    null_matched: K::Bitmap,
    // Memcmp serializer for the upstream pk (all columns ascending).
    pk_serializer: OrderedRowSerde,
    // The primary state table holding the actual rows.
    state: TableInner<S>,
    // Degree table, present only for joins that track match degrees.
    degree_state: Option<TableInner<S>>,
    // Cached `degree_state.is_some()`, checked on every insert/delete.
    need_degree_table: bool,
    // Whether the upstream pk is fully contained in the join key; enables
    // creating a complete cache entry on insert miss (see `insert`).
    pk_contained_in_jk: bool,
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}
220
221impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
222 pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
223 (&self.state.order_key_indices, &mut self.degree_state)
224 }
225
226 pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
233 &'a mut self,
234 key: &'a K,
235 ) -> StreamExecutorResult<(
236 impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
237 &'a [usize],
238 &'a mut Option<TableInner<S>>,
239 )> {
240 let degree_state = &mut self.degree_state;
241 let (order_key_indices, pk_indices, state_table) = (
242 &self.state.order_key_indices,
243 &self.state.pk_indices,
244 &mut self.state.table,
245 );
246 let degrees = if let Some(degree_state) = degree_state {
247 Some(fetch_degrees(key, &self.join_key_data_types, °ree_state.table).await?)
248 } else {
249 None
250 };
251 let stream = into_stream(
252 &self.join_key_data_types,
253 pk_indices,
254 &self.pk_serializer,
255 state_table,
256 key,
257 degrees,
258 );
259 Ok((stream, order_key_indices, &mut self.degree_state))
260 }
261}
262
263#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
264pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
265 join_key_data_types: &'a [DataType],
266 pk_indices: &'a [usize],
267 pk_serializer: &'a OrderedRowSerde,
268 state_table: &'a StateTable<S>,
269 key: &'a K,
270 degrees: Option<Vec<DegreeType>>,
271) {
272 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
273 let decoded_key = key.deserialize(join_key_data_types)?;
274 let table_iter = state_table
275 .iter_with_prefix_respecting_watermark(&decoded_key, sub_range, PrefetchOptions::default())
276 .await?;
277
278 #[for_await]
279 for (i, entry) in table_iter.enumerate() {
280 let encoded_row = entry?;
281 let encoded_pk = encoded_row
282 .as_ref()
283 .project(pk_indices)
284 .memcmp_serialize(pk_serializer);
285 let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
286 yield (encoded_pk, join_row);
287 }
288}
289
290async fn fetch_degrees<K: HashKey, S: StateStore>(
314 key: &K,
315 join_key_data_types: &[DataType],
316 degree_state_table: &StateTable<S>,
317) -> StreamExecutorResult<Vec<DegreeType>> {
318 let key = key.deserialize(join_key_data_types)?;
319 let mut degrees = vec![];
320 let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
321 let table_iter = degree_state_table
322 .iter_with_prefix_respecting_watermark(key, sub_range, PrefetchOptions::default())
323 .await?;
324 let degree_col_idx = degree_col_idx_in_row(degree_state_table);
325 #[for_await]
326 for entry in table_iter {
327 let degree_row = entry?;
328 debug_assert!(
329 degree_row.len() > degree_col_idx,
330 "degree row should have at least pk_len + 1 columns"
331 );
332 let degree_i64 = degree_row
333 .datum_at(degree_col_idx)
334 .expect("degree should not be NULL");
335 degrees.push(degree_i64.into_int64() as u64);
336 }
337 Ok(degrees)
338}
339
340fn degree_col_idx_in_row<S: StateStore>(degree_state_table: &StateTable<S>) -> usize {
341 let degree_col_idx = degree_state_table.pk_indices().len();
343 match degree_state_table.value_indices() {
344 Some(value_indices) => value_indices
345 .iter()
346 .position(|idx| *idx == degree_col_idx)
347 .expect("degree column should be included in value indices"),
348 None => degree_col_idx,
349 }
350}
351
352pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
356 order_key_indices: &[usize],
357 degree_state: &mut TableInner<S>,
358 matched_row: &mut JoinRow<impl Row>,
359) {
360 let inequality_idx = degree_state.degree_inequality_idx;
361 let old_degree_row = build_degree_row(
362 order_key_indices,
363 matched_row.degree,
364 inequality_idx,
365 &matched_row.row,
366 );
367 if INCREMENT {
368 matched_row.degree += 1;
369 } else {
370 matched_row.degree -= 1;
372 }
373 let new_degree_row = build_degree_row(
374 order_key_indices,
375 matched_row.degree,
376 inequality_idx,
377 &matched_row.row,
378 );
379 degree_state.table.update(old_degree_row, new_degree_row);
380}
381
/// A state table bundled with the index metadata the join needs to use it.
pub struct TableInner<S: StateStore> {
    // Indices of the upstream pk columns within a row.
    pub(crate) pk_indices: Vec<usize>,
    // Indices of the join key columns within a row (used for error context).
    join_key_indices: Vec<usize>,
    // The table's own pk indices, i.e. the state table's order key.
    order_key_indices: Vec<usize>,
    // Index of the inequality column in the degree row, if any — passed to
    // `build_degree_row`; NOTE(review): exact semantics inferred, confirm.
    pub(crate) degree_inequality_idx: Option<usize>,
    pub(crate) table: StateTable<S>,
}
398
399impl<S: StateStore> TableInner<S> {
400 pub fn new(
401 pk_indices: Vec<usize>,
402 join_key_indices: Vec<usize>,
403 table: StateTable<S>,
404 degree_inequality_idx: Option<usize>,
405 ) -> Self {
406 let order_key_indices = table.pk_indices().to_vec();
407 Self {
408 pk_indices,
409 join_key_indices,
410 order_key_indices,
411 degree_inequality_idx,
412 table,
413 }
414 }
415
416 fn error_context(&self, row: &impl Row) -> String {
417 let pk = row.project(&self.pk_indices);
418 let jk = row.project(&self.join_key_indices);
419 format!(
420 "join key: {}, pk: {}, row: {}, state_table_id: {}",
421 jk.display(),
422 pk.display(),
423 row.display(),
424 self.table.table_id()
425 )
426 }
427}
428
429impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
430 #[expect(clippy::too_many_arguments)]
432 pub fn new(
433 watermark_sequence: AtomicU64Ref,
434 join_key_data_types: Vec<DataType>,
435 state_join_key_indices: Vec<usize>,
436 state_all_data_types: Vec<DataType>,
437 state_table: StateTable<S>,
438 state_pk_indices: Vec<usize>,
439 degree_state: Option<TableInner<S>>,
440 null_matched: K::Bitmap,
441 pk_contained_in_jk: bool,
442 metrics: Arc<StreamingMetrics>,
443 actor_id: ActorId,
444 fragment_id: FragmentId,
445 side: &'static str,
446 ) -> Self {
447 let pk_data_types = state_pk_indices
449 .iter()
450 .map(|i| state_all_data_types[*i].clone())
451 .collect();
452 let pk_serializer = OrderedRowSerde::new(
453 pk_data_types,
454 vec![OrderType::ascending(); state_pk_indices.len()],
455 );
456
457 let join_table_id = state_table.table_id();
458 let state = TableInner {
459 pk_indices: state_pk_indices,
460 join_key_indices: state_join_key_indices,
461 order_key_indices: state_table.pk_indices().to_vec(),
462 degree_inequality_idx: None,
463 table: state_table,
464 };
465
466 let need_degree_table = degree_state.is_some();
467
468 let metrics_info = MetricsInfo::new(
469 metrics.clone(),
470 join_table_id,
471 actor_id,
472 format!("hash join {}", side),
473 );
474
475 let cache = ManagedLruCache::unbounded_with_hasher(
476 watermark_sequence,
477 metrics_info,
478 PrecomputedBuildHasher,
479 );
480
481 Self {
482 inner: cache,
483 join_key_data_types,
484 null_matched,
485 pk_serializer,
486 state,
487 degree_state,
488 need_degree_table,
489 pk_contained_in_jk,
490 metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
491 _marker: std::marker::PhantomData,
492 }
493 }
494
495 pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
496 self.state.table.init_epoch(epoch).await?;
497 if let Some(degree_state) = &mut self.degree_state {
498 degree_state.table.init_epoch(epoch).await?;
499 }
500 Ok(())
501 }
502}
503
504impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
505 pub async fn post_yield_barrier(
506 self,
507 vnode_bitmap: Option<Arc<Bitmap>>,
508 ) -> StreamExecutorResult<Option<bool>> {
509 let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
510 if let Some(degree_state) = self.degree_state {
511 let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
512 }
513 let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
514 if cache_may_stale.unwrap_or(false) {
515 self.inner.clear();
516 }
517 Ok(cache_may_stale)
518 }
519}
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Propagate a state-cleaning watermark to the state table and, when
    /// present, the degree table.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }

    /// Look up the cached entry for `key`, moving it out of the cache slot on
    /// a hit. A miss only records metrics; the caller is responsible for
    /// loading the entry from storage and putting it back via `update_state`.
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            // Take the boxed entry out, leaving the wrapper slot in place.
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            CacheResult::Hit(state.take())
        } else {
            self.metrics.lookup_miss_count += 1;
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }

    /// Commit both state tables at `epoch` and flush buffered metrics.
    ///
    /// Returns a post-commit handle that must be driven via
    /// `post_yield_barrier` after the barrier is yielded.
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
        self.metrics.flush();
        let state_post_commit = self.state.table.commit(epoch).await?;
        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
            Some(degree_state.table.commit(epoch).await?)
        } else {
            None
        };
        Ok(JoinHashMapPostCommit {
            state: state_post_commit,
            degree_state: degree_state_post_commit,
            inner: &mut self.inner,
        })
    }

    /// Opportunistically flush both tables between barriers.
    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state.table.try_flush().await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.try_flush().await?;
        }
        Ok(())
    }

    /// Insert a row, routing to the degree-aware path when this join maintains
    /// a degree table, and to the plain-row path otherwise.
    pub fn insert_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.insert(key, value)
        } else {
            self.insert_row(key, value.row)
        }
    }

    /// Insert a join row into the cache (when applicable) and the state
    /// table(s).
    ///
    /// Cache behavior: an existing entry is updated in place; on a miss, a new
    /// entry is created only when the pk is contained in the join key —
    /// NOTE(review): presumably because the single row then forms a complete
    /// entry; confirm. Other misses leave the cache untouched.
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        if self.inner.contains(key) {
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value))
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value))
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        // Persist to storage: with a degree table, write both the row and its
        // degree row; otherwise write the bare row.
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(
                &self.state.order_key_indices,
                degree_state.degree_inequality_idx,
            );
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }

    /// Insert a bare row (degree 0) — the degree-less counterpart of `insert`.
    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        let join_row = JoinRow::new(&value, 0);
        self.insert(key, join_row)?;
        Ok(())
    }

    /// Remove a row from the cached entry only (no storage write). A cache
    /// miss for `key` is a no-op.
    pub fn delete_row_in_mem(&mut self, key: &K, value: &impl Row) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);
            entry
                .remove(pk)
                .with_context(|| self.state.error_context(&value))?;
        }
        Ok(())
    }

    /// Delete a row, routing to the degree-aware or plain-row path.
    pub fn delete_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.delete(key, value)
        } else {
            self.delete_row(key, value.row)
        }
    }

    /// Delete a join row from the cache and both state tables.
    ///
    /// # Panics
    /// Panics if no degree table is configured.
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        self.delete_row_in_mem(key, &value.row)?;

        let degree_state = self.degree_state.as_mut().expect("degree table missing");
        let (row, degree) = value.to_table_rows(
            &self.state.order_key_indices,
            degree_state.degree_inequality_idx,
        );
        self.state.table.delete(row);
        degree_state.table.delete(degree);
        Ok(())
    }

    /// Delete a bare row from the cache and the state table (no degree table).
    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        self.delete_row_in_mem(key, &value)?;

        self.state.table.delete(value);
        Ok(())
    }

    /// Put an entry (back) into the cache, e.g. after `take_state_opt` or a
    /// storage load.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict cache entries according to the LRU policy.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Number of join keys currently cached.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    /// Bitmap of null-matched join key columns.
    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    /// Id of the underlying state table.
    pub fn table_id(&self) -> TableId {
        self.state.table.table_id()
    }

    /// Data types of the join key columns.
    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Memcmp-serialize the upstream pk columns of `row`.
    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
}
715
/// Handle returned by `JoinHashMap::flush`; must be consumed via
/// `post_yield_barrier` after the barrier is yielded downstream.
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    // Held so the cache can be cleared if the commit made it stale.
    inner: &'a mut JoinHashMapInner<K, E>,
}
722
723use risingwave_common::catalog::TableId;
724use risingwave_common_estimate_size::KvSize;
725use thiserror::Error;
726
727use super::*;
728use crate::executor::prelude::{Stream, try_stream};
729
/// All cached rows for one join key, keyed by memcmp-encoded pk, with running
/// heap-size bookkeeping for the LRU cache.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    // Rows under this join key, keyed by serialized pk.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    // Running estimate of the heap usage of `cached`.
    kv_heap_size: KvSize,
}
741
742impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
743 fn estimated_heap_size(&self) -> usize {
744 self.kv_heap_size.size()
747 }
748}
749
/// Errors from mutating a `JoinEntryState` under strict consistency.
#[derive(Error, Debug)]
pub enum JoinEntryError {
    /// Inserting a pk that is already present in the entry.
    #[error("double inserting a join state entry")]
    Occupied,
    /// Removing a pk that is not present in the entry.
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
}
757
impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert a row under `key` (serialized pk).
    ///
    /// Under strict consistency, a duplicate key yields
    /// `JoinEntryError::Occupied`. In best-effort mode, an existing row is
    /// replaced and the double-insert is only reported as a consistency error.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // Best-effort: drop any existing row first so the insert below
            // cannot collide.
            if let Some(old_value) = self.cached.remove(&key) {
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        // NOTE(review): size is added before `try_insert`; if `Occupied` is
        // returned in strict mode, the accounting is not rolled back — confirm
        // whether callers treat that error as fatal.
        self.kv_heap_size.add(&key, &value);

        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // Tolerated, but surfaced for diagnostics.
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }

    /// Remove the row under `pk`.
    ///
    /// A missing pk is an error under strict consistency and only a reported
    /// consistency error (returning `Ok`) otherwise.
    pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> {
        if let Some(value) = self.cached.remove(&pk) {
            self.kv_heap_size.sub(&pk, &value);
            Ok(())
        } else if enable_strict_consistency() {
            Err(JoinEntryError::Remove)
        } else {
            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
            Ok(())
        }
    }

    /// Decode and return the row under `pk`, if present.
    pub fn get(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
        self.cached
            .get(pk)
            .map(|encoded| encoded.decode(data_types))
    }

    /// Iterate over all rows, yielding each encoded row (mutably) together
    /// with its decode result.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<E::DecodedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }

    /// Number of rows cached under this join key.
    pub fn len(&self) -> usize {
        self.cached.len()
    }
}
836
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    /// Insert every row of `data_chunk` into `managed_state`, keyed by the
    /// memcmp-serialized pk columns.
    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            // NOTE(review): the pk row is re-projected with `value_indices`
            // (0..len-1); for these 2-column rows that is just index 0 —
            // verify this projection is intentional.
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let join_row = JoinRow { row, degree: 0 };
            managed_state.insert(pk, E::encode(&join_row)).unwrap();
        }
    }

    /// Assert the entry's rows decode to (`col1[i]`, `col2[i]`) with degree 0,
    /// in iteration (pk-sorted) order.
    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        // First chunk: inserted out of order; iteration is pk-sorted (3, 2, 1).
        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        insert_chunk::<MemoryEncoding>(&mut managed_state, &pk_indices, &col_types, &data_chunk1);
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // Second chunk merges in; expectation covers all five rows pk-sorted.
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(&mut managed_state, &pk_indices, &col_types, &data_chunk2);
        check(&mut managed_state, &col_types, &col1, &col2);
    }
}
919}