risingwave_stream/executor/join/hash_join.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use futures::future::{join, try_join};
use futures::{Stream, StreamExt, pin_mut, stream};
use futures_async_stream::{for_await, try_stream};
use join_row_set::JoinRowSet;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::table::KeyedRow;
use thiserror_ext::AsReport;

use super::row::{CachedJoinRow, DegreeType, build_degree_row};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::{JoinEncoding, StreamExecutorError};
use crate::task::{ActorId, AtomicU64Ref, FragmentId};
49
/// Memcomparable encoding of a primary key.
type PkType = Vec<u8>;
/// Memcomparable encoding of the inequality key used by `AsOf` joins.
type InequalKeyType = Vec<u8>;

/// The cached join entry for one join key, stored behind a `Box`.
pub type HashValueType<E> = Box<JoinEntryState<E>>;
55
56impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
57    fn estimated_heap_size(&self) -> usize {
58        self.as_ref().estimated_heap_size()
59    }
60}
61
/// The wrapper for [`JoinEntryState`] which should be `Some` most of the time in the hash table.
///
/// When the executor is operating on the specific entry of the map, it can hold the ownership of
/// the entry by taking the value out of the `Option`, instead of holding a mutable reference to the
/// map, which can make the compiler happy.
///
/// Dereferencing a wrapper whose value is currently taken out panics (see the
/// `Deref`/`DerefMut` impls below).
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
68
/// Result of a join-key lookup in the in-memory LRU cache.
pub(crate) enum CacheResult<E: JoinEncoding> {
    /// Will never match, will not be in cache at all.
    NeverMatch,
    /// Cache-miss: the entry is not cached; the caller decides how to proceed.
    Miss,
    /// Cache-hit: the cached entry is taken out and handed to the caller.
    Hit(HashValueType<E>),
}
74
impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        // Delegates to the inner `Option<Box<JoinEntryState>>`.
        self.0.estimated_heap_size()
    }
}
80
impl<E: JoinEncoding> HashValueWrapper<E> {
    /// Panic message used when the inner value is unexpectedly `None`.
    const MESSAGE: &'static str = "the state should always be `Some`";

    /// Take the value out of the wrapper. Panic if the value is `None`.
    /// Leaves `None` behind; presumably the caller puts the state back afterwards
    /// (see the struct-level docs).
    pub fn take(&mut self) -> HashValueType<E> {
        self.0.take().expect(Self::MESSAGE)
    }
}
89
90impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
91    type Target = HashValueType<E>;
92
93    fn deref(&self) -> &Self::Target {
94        self.0.as_ref().expect(Self::MESSAGE)
95    }
96}
97
98impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
99    fn deref_mut(&mut self) -> &mut Self::Target {
100        self.0.as_mut().expect(Self::MESSAGE)
101    }
102}
103
104type JoinHashMapInner<K, E> = ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher>;
105
/// Metrics for [`JoinHashMap`]: counts are accumulated locally and folded into
/// the guarded metric counters in `flush`.
pub struct JoinHashMapMetrics {
    // Local counters, reset to zero on each `flush`.
    /// How many lookups missed the cache of the join executor.
    lookup_miss_count: usize,
    /// Total number of cache lookups.
    total_lookup_count: usize,
    /// How many times we missed the cache when inserting a row.
    insert_cache_miss_count: usize,

    // Guarded metric counters that the local counters are flushed into.
    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}
119
120impl JoinHashMapMetrics {
121    pub fn new(
122        metrics: &StreamingMetrics,
123        actor_id: ActorId,
124        fragment_id: FragmentId,
125        side: &'static str,
126        join_table_id: TableId,
127    ) -> Self {
128        let actor_id = actor_id.to_string();
129        let fragment_id = fragment_id.to_string();
130        let join_table_id = join_table_id.to_string();
131        let join_lookup_total_count_metric = metrics
132            .join_lookup_total_count
133            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
134        let join_lookup_miss_count_metric = metrics
135            .join_lookup_miss_count
136            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
137        let join_insert_cache_miss_count_metrics = metrics
138            .join_insert_cache_miss_count
139            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
140
141        Self {
142            lookup_miss_count: 0,
143            total_lookup_count: 0,
144            insert_cache_miss_count: 0,
145            join_lookup_total_count_metric,
146            join_lookup_miss_count_metric,
147            join_insert_cache_miss_count_metrics,
148        }
149    }
150
151    pub fn flush(&mut self) {
152        self.join_lookup_total_count_metric
153            .inc_by(self.total_lookup_count as u64);
154        self.join_lookup_miss_count_metric
155            .inc_by(self.lookup_miss_count as u64);
156        self.join_insert_cache_miss_count_metrics
157            .inc_by(self.insert_cache_miss_count as u64);
158        self.total_lookup_count = 0;
159        self.lookup_miss_count = 0;
160        self.insert_cache_miss_count = 0;
161    }
162}
163
/// Inequality key description for `AsOf` join.
struct InequalityKeyDesc {
    /// Index of the inequality column in the input row.
    idx: usize,
    /// Serializer for memcomparable-encoding that single column.
    serializer: OrderedRowSerde,
}
169
170impl InequalityKeyDesc {
171    /// Serialize the inequality key from a row.
172    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
173        let indices = vec![self.idx];
174        let inequality_key = row.project(&indices);
175        inequality_key.memcmp_serialize(&self.serializer)
176    }
177}
178
/// In-memory join state for one side of a streaming hash join: an LRU cache of
/// join entries on top of the backing state table (and optional degree table).
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Store the join states.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns
    join_key_data_types: Vec<DataType>,
    /// Null safe bitmap for each join pair
    null_matched: K::Bitmap,
    /// The memcomparable serializer of primary key.
    pk_serializer: OrderedRowSerde,
    /// State table. Contains the data from upstream.
    state: TableInner<S>,
    /// Degree table.
    ///
    /// The degree is generated from the hash join executor.
    /// Each row in `state` has a corresponding degree in `degree state`.
    /// A degree value `d` in for a row means the row has `d` matched row in the other join side.
    ///
    /// It will only be used when needed in a side.
    ///
    /// - Full Outer: both side
    /// - Left Outer/Semi/Anti: left side
    /// - Right Outer/Semi/Anti: right side
    /// - Inner: neither side.
    ///
    /// Should be set to `None` if `need_degree_table` was set to `false`.
    ///
    /// The degree of each row will tell us if we need to emit `NULL` for the row.
    /// For instance, given `lhs LEFT JOIN rhs`,
    /// If the degree of a row in `lhs` is 0, it means the row does not have a match in `rhs`.
    /// If the degree of a row in `lhs` is 2, it means the row has two matches in `rhs`.
    /// Now, when emitting the result of the join, we need to emit `NULL` for the row in `lhs` if
    /// the degree is 0.
    ///
    /// Why don't just use a boolean value instead of a degree count?
    /// Consider the case where we delete a matched record from `rhs`.
    /// Since we can delete a record,
    /// there must have been a record in `rhs` that matched the record in `lhs`.
    /// So this value is `true`.
    /// But we don't know how many records are matched after removing this record,
    /// since we only stored a boolean value rather than the count.
    /// Hence we need to store the count of matched records.
    degree_state: Option<TableInner<S>>,
    // TODO(kwannoel): Make this `const` instead.
    /// If degree table is need
    need_degree_table: bool,
    /// Pk is part of the join key.
    pk_contained_in_jk: bool,
    /// Inequality key description for `AsOf` join.
    inequality_key_desc: Option<InequalityKeyDesc>,
    /// Metrics of the hash map
    metrics: JoinHashMapMetrics,
    /// Ties the otherwise-unused join-encoding parameter `E` to the struct.
    _marker: std::marker::PhantomData<E>,
}
232
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Returns the state table's order key indices together with a mutable
    /// reference to the optional degree table, so both can be held at once.
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// NOTE(kwannoel): This allows us to concurrently stream records from the `state_table`,
    /// and update the degree table, without using `unsafe` code.
    ///
    /// This is because we obtain separate references to separate parts of the `JoinHashMap`,
    /// instead of reusing the same reference to `JoinHashMap` for concurrent read access to `state_table`,
    /// and write access to the degree table.
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        // Split borrows of disjoint fields so the stream (read) and the degree
        // table (write) can be used concurrently by the caller.
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        // Degrees are materialized upfront (see `fetch_degrees`) so the returned
        // stream does not need to keep borrowing the degree table.
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}
274
/// Streams all rows under the given join key from the state table, yielding
/// `(memcomparable pk, JoinRow)` pairs. Degrees, when provided, are attached
/// positionally; otherwise the degree defaults to 0.
#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    // Scan the whole range under the join-key prefix.
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix_respecting_watermark(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        // NOTE(review): assumes `degrees`, when present, has one entry per row in
        // iteration order — verify against `fetch_degrees`, which scans the same
        // prefix on the degree table.
        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
        yield (encoded_pk, join_row);
    }
}
301
/// We use this to fetch ALL degrees into memory.
/// We use this instead of a streaming interface.
/// It is necessary because we must update the `degree_state_table` concurrently.
/// If we obtain the degrees in a stream,
/// we will need to hold an immutable reference to the state table for the entire lifetime,
/// preventing us from concurrently updating the state table.
///
/// The cost of fetching all degrees upfront is acceptable. We currently already do so
/// in `fetch_cached_state`.
/// The memory use should be limited since we only store a u64.
///
/// Let's say we have amplification of 1B, we will have 1B * 8 bytes ~= 8GB
///
/// We can also have further optimization, to permit breaking the streaming update,
/// to flush the in-memory degrees, if this is proven to have high memory consumption.
///
/// TODO(kwannoel): Perhaps we can cache these separately from matched rows too.
/// Because matched rows may occupy a larger capacity.
///
/// Argument for this:
/// We only hit this when cache miss. When cache miss, we will have this as one off cost.
/// Keeping this cached separately from matched rows is beneficial.
/// Then we can evict matched rows, without touching the degrees.
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    // Scan the whole range under the join-key prefix, in iteration order.
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix_respecting_watermark(key, sub_range, PrefetchOptions::default())
        .await?;
    let degree_col_idx = degree_col_idx_in_row(degree_state_table);
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        debug_assert!(
            degree_row.len() > degree_col_idx,
            "degree row should have at least pk_len + 1 columns"
        );
        let degree_i64 = degree_row
            .datum_at(degree_col_idx)
            .expect("degree should not be NULL");
        // Degrees are stored as int64 but handled as `DegreeType` (u64) in memory.
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}
351
352fn degree_col_idx_in_row<S: StateStore>(degree_state_table: &StateTable<S>) -> usize {
353    // Degree column is at index pk_len in the full schema: [pk..., _degree, inequality?].
354    let degree_col_idx = degree_state_table.pk_indices().len();
355    match degree_state_table.value_indices() {
356        Some(value_indices) => value_indices
357            .iter()
358            .position(|idx| *idx == degree_col_idx)
359            .expect("degree column should be included in value indices"),
360        None => degree_col_idx,
361    }
362}
363
// NOTE(kwannoel): This is not really specific to `TableInner`.
// A degree table is `TableInner`, a `TableInner` might not be a degree table.
// Hence we don't specify it in its impl block.
/// Bumps the degree of `matched_row` up (`INCREMENT == true`) or down by one,
/// writing the change through to the degree state table as an update from the
/// old degree row to the new one.
pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
    order_key_indices: &[usize],
    degree_state: &mut TableInner<S>,
    matched_row: &mut JoinRow<impl Row>,
) {
    let inequality_idx = degree_state.degree_inequality_idx;
    // The degree row as currently stored, before the change.
    let old_degree_row = build_degree_row(
        order_key_indices,
        matched_row.degree,
        inequality_idx,
        &matched_row.row,
    );
    if INCREMENT {
        matched_row.degree += 1;
    } else {
        // DECREMENT
        // NOTE(review): the degree appears to be unsigned (`DegreeType` is pushed
        // as u64 in `fetch_degrees`); decrementing a zero degree would underflow —
        // confirm callers only decrement rows that were previously matched.
        matched_row.degree -= 1;
    }
    let new_degree_row = build_degree_row(
        order_key_indices,
        matched_row.degree,
        inequality_idx,
        &matched_row.row,
    );
    degree_state.table.update(old_degree_row, new_degree_row);
}
393
/// One physical table backing a join side: either the upstream data table or a
/// degree table.
pub struct TableInner<S: StateStore> {
    /// Indices of the (cache) pk in a state row
    pk_indices: Vec<usize>,
    /// Indices of the join key in a state row
    join_key_indices: Vec<usize>,
    /// The order key of the join side has the following format:
    /// | `join_key` ... | pk ... |
    /// Where `join_key` contains all the columns not in the pk.
    /// It should be a superset of the pk.
    order_key_indices: Vec<usize>,
    /// Optional: index of inequality column in the input row for degree table.
    /// Used for inequality-based watermark cleaning of degree tables.
    /// When present, the degree table schema is: [pk..., _degree, `inequality_val`].
    pub(crate) degree_inequality_idx: Option<usize>,
    /// The underlying state table.
    pub(crate) table: StateTable<S>,
}
410
411impl<S: StateStore> TableInner<S> {
412    pub fn new(
413        pk_indices: Vec<usize>,
414        join_key_indices: Vec<usize>,
415        table: StateTable<S>,
416        degree_inequality_idx: Option<usize>,
417    ) -> Self {
418        let order_key_indices = table.pk_indices().to_vec();
419        Self {
420            pk_indices,
421            join_key_indices,
422            order_key_indices,
423            degree_inequality_idx,
424            table,
425        }
426    }
427
428    fn error_context(&self, row: &impl Row) -> String {
429        let pk = row.project(&self.pk_indices);
430        let jk = row.project(&self.join_key_indices);
431        format!(
432            "join key: {}, pk: {}, row: {}, state_table_id: {}",
433            jk.display(),
434            pk.display(),
435            row.display(),
436            self.table.table_id()
437        )
438    }
439}
440
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] with the given LRU capacity.
    ///
    /// NOTE(review): despite the wording above, the cache is created via
    /// `unbounded_with_hasher`; eviction presumably follows `watermark_sequence`
    /// rather than a fixed capacity — confirm against `ManagedLruCache`.
    ///
    /// `degree_state` is `Some` iff this side needs a degree table;
    /// `need_degree_table` is derived from its presence.
    /// `inequality_key_idx` is the column index of the `AsOf` inequality key, if any.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        // TODO: unify pk encoding with state table.
        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        // Serializer for the single `AsOf` inequality column, if configured.
        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            degree_inequality_idx: inequality_key_idx,
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        let cache = ManagedLruCache::unbounded_with_hasher(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    /// Initializes the epoch of the state table (and the degree table, if any).
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}
525
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    /// Completes the commit after the barrier was yielded: forwards the new vnode
    /// bitmap to the state table's post-commit (and the degree table's, if any),
    /// and clears the in-memory cache when the post-commit reports it may be stale.
    ///
    /// Returns `Some(cache_may_stale)`, or `None` when the underlying post-commit
    /// gives no answer.
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            // Degree table's staleness result is intentionally ignored; the state
            // table's answer below drives cache invalidation.
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        if cache_may_stale.unwrap_or(false) {
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}
542impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Forwards `watermark` to the state table and, if present, the degree table.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        // TODO: remove data in cache.
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }
550
    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
    /// `update_state` after some operations to put the state back.
    ///
    /// If the state is cached, it is taken out and returned as [`CacheResult::Hit`].
    /// Unlike [`Self::take_state`], a cache miss does NOT fall back to fetching from
    /// remote storage — [`CacheResult::Miss`] is returned and the caller decides how
    /// to proceed.
    ///
    /// Note: This will NOT remove anything from remote storage.
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            // Do not update the LRU statistics here with `peek_mut` since we will put the state
            // back.
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            CacheResult::Hit(state.take())
        } else {
            self.metrics.lookup_miss_count += 1;
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }
573
    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
    /// `update_state` after some operations to put the state back.
    ///
    /// If the state does not exist in the cache, fetch the remote storage and return. If it still
    /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be
    /// returned.
    ///
    /// Unlike [`Self::take_state_opt`], a cache miss here always falls back to
    /// [`Self::fetch_cached_state`] and yields an entry.
    ///
    /// Note: This will NOT remove anything from remote storage.
    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
        self.metrics.total_lookup_count += 1;
        let state = if self.inner.contains(key) {
            // Do not update the LRU statistics here with `peek_mut` since we will put the state
            // back.
            let mut state = self.inner.peek_mut(key).unwrap();
            state.take()
        } else {
            self.metrics.lookup_miss_count += 1;
            self.fetch_cached_state(key).await?.into()
        };
        Ok(state)
    }
595
    /// Fetch cache from the state store. Should only be called if the key does not exist in memory.
    /// Will return an empty `JoinEntryState` even when state does not exist in remote.
    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
        let key = key.deserialize(&self.join_key_data_types)?;

        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();

        // With a degree table, rows and degree rows must be fetched and paired up;
        // without one, all degrees default to 0 (see the `else` branch).
        if self.need_degree_table {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );
            let degree_state = self.degree_state.as_ref().unwrap();
            let degree_col_idx = degree_col_idx_in_row(&degree_state.table);
            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );

            // Open both iterators concurrently.
            let (table_iter, degree_table_iter) =
                try_join(table_iter_fut, degree_table_iter_fut).await?;

            let mut pinned_table_iter = std::pin::pin!(table_iter);
            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);

            // For better tolerating inconsistent stream, we have to first buffer all rows and
            // degree rows, and check the number of them, then iterate on them.
            let mut rows = vec![];
            let mut degree_rows = vec![];
            let mut inconsistency_happened = false;
            loop {
                // Advance both iterators in lockstep; a one-sided `None` means the
                // two tables disagree on the number of rows for this key.
                let (row, degree_row) =
                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
                let (row, degree_row) = match (row, degree_row) {
                    (None, None) => break,
                    (None, Some(_)) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(_), None) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, input table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(r), Some(d)) => (r, d),
                };

                let row = row?;
                let degree_row = degree_row?;
                rows.push(row);
                degree_rows.push(degree_row);
            }

            if inconsistency_happened {
                // Pk-based row-degree pairing.
                // Merge-join the two buffered, key-ordered lists, skipping entries
                // that exist on only one side.
                assert_ne!(rows.len(), degree_rows.len());

                let row_iter = stream::iter(rows.into_iter()).peekable();
                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
                pin_mut!(row_iter);
                pin_mut!(degree_row_iter);

                loop {
                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
                        (None, _) | (_, None) => break,
                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
                            Ordering::Greater => {
                                degree_row_iter.next().await;
                            }
                            Ordering::Less => {
                                row_iter.next().await;
                            }
                            Ordering::Equal => {
                                let row =
                                    row_iter.next().await.expect("we matched some(row) above");
                                let degree_row = degree_row_iter
                                    .next()
                                    .await
                                    .expect("we matched some(degree_row) above");
                                let pk = row
                                    .as_ref()
                                    .project(&self.state.pk_indices)
                                    .memcmp_serialize(&self.pk_serializer);
                                let degree_i64 = degree_row
                                    .datum_at(degree_col_idx)
                                    .expect("degree should not be NULL");
                                let inequality_key = self
                                    .inequality_key_desc
                                    .as_ref()
                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                                entry_state
                                    .insert(
                                        pk,
                                        E::encode(&JoinRow::new(
                                            row.row(),
                                            degree_i64.into_int64() as u64,
                                        )),
                                        inequality_key,
                                    )
                                    .with_context(|| self.state.error_context(row.row()))?;
                            }
                        },
                    }
                }
            } else {
                // 1 to 1 row-degree pairing.
                // Actually it's possible that both the input data table and the degree table missed
                // some equal number of rows, but let's ignore this case because it should be rare.

                assert_eq!(rows.len(), degree_rows.len());

                #[for_await]
                for (row, degree_row) in
                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
                {
                    let row: KeyedRow<_> = row;
                    let degree_row: KeyedRow<_> = degree_row;

                    let pk1 = row.key();
                    let pk2 = degree_row.key();
                    debug_assert_eq!(
                        pk1, pk2,
                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
                    );
                    let pk = row
                        .as_ref()
                        .project(&self.state.pk_indices)
                        .memcmp_serialize(&self.pk_serializer);
                    let inequality_key = self
                        .inequality_key_desc
                        .as_ref()
                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                    let degree_i64 = degree_row
                        .datum_at(degree_col_idx)
                        .expect("degree should not be NULL");
                    entry_state
                        .insert(
                            pk,
                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
                            inequality_key,
                        )
                        .with_context(|| self.state.error_context(row.row()))?;
                }
            }
        } else {
            // No degree table for this side: cache all rows with degree 0.
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter = self
                .state
                .table
                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
                .await?;

            #[for_await]
            for entry in table_iter {
                let row: KeyedRow<_> = entry?;
                let pk = row
                    .as_ref()
                    .project(&self.state.pk_indices)
                    .memcmp_serialize(&self.pk_serializer);
                let inequality_key = self
                    .inequality_key_desc
                    .as_ref()
                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                entry_state
                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
                    .with_context(|| self.state.error_context(row.row()))?;
            }
        };

        Ok(entry_state)
    }
779
780    pub async fn flush(
781        &mut self,
782        epoch: EpochPair,
783    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
784        self.metrics.flush();
785        let state_post_commit = self.state.table.commit(epoch).await?;
786        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
787            Some(degree_state.table.commit(epoch).await?)
788        } else {
789            None
790        };
791        Ok(JoinHashMapPostCommit {
792            state: state_post_commit,
793            degree_state: degree_state_post_commit,
794            inner: &mut self.inner,
795        })
796    }
797
798    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
799        self.state.table.try_flush().await?;
800        if let Some(degree_state) = &mut self.degree_state {
801            degree_state.table.try_flush().await?;
802        }
803        Ok(())
804    }
805
806    pub fn insert_handle_degree(
807        &mut self,
808        key: &K,
809        value: JoinRow<impl Row>,
810    ) -> StreamExecutorResult<()> {
811        if self.need_degree_table {
812            self.insert(key, value)
813        } else {
814            self.insert_row(key, value.row)
815        }
816    }
817
    /// Insert a join row.
    ///
    /// Updates the in-memory cache when the key is already cached, or seeds a
    /// fresh cache entry when `pk_contained_in_jk` holds (the join key then
    /// uniquely determines the rows, so the entry cannot be missing rows from
    /// storage). The row is always staged into the state table's write
    /// buffer, together with a degree row when this side maintains degrees.
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        // Serialized inequality key, only present for AS OF joins.
        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        // TODO(yuhao): avoid this `contains`.
        // https://github.com/risingwavelabs/risingwave/issues/9233
        if self.inner.contains(key) {
            // Update cache
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            // Refill cache when the join key exist in neither cache or storage.
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }
        // Otherwise: key not cached and pk ⊄ jk — the cache is left untouched
        // and only the state table below is updated.

        // Update the flush buffer.
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(
                &self.state.order_key_indices,
                degree_state.degree_inequality_idx,
            );
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }
858
859    /// Insert a row.
860    /// Used when the side does not need to update degree.
861    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
862        let join_row = JoinRow::new(&value, 0);
863        self.insert(key, join_row)?;
864        Ok(())
865    }
866
867    pub fn delete_row_in_mem(&mut self, key: &K, value: &impl Row) -> StreamExecutorResult<()> {
868        if let Some(mut entry) = self.inner.get_mut(key) {
869            let pk = (&value)
870                .project(&self.state.pk_indices)
871                .memcmp_serialize(&self.pk_serializer);
872
873            let inequality_key = self
874                .inequality_key_desc
875                .as_ref()
876                .map(|desc| desc.serialize_inequal_key_from_row(value));
877            entry
878                .remove(pk, inequality_key.as_ref())
879                .with_context(|| self.state.error_context(&value))?;
880        }
881        Ok(())
882    }
883
884    pub fn delete_handle_degree(
885        &mut self,
886        key: &K,
887        value: JoinRow<impl Row>,
888    ) -> StreamExecutorResult<()> {
889        if self.need_degree_table {
890            self.delete(key, value)
891        } else {
892            self.delete_row(key, value.row)
893        }
894    }
895
896    /// Delete a join row
897    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
898        self.delete_row_in_mem(key, &value.row)?;
899
900        // If no cache maintained, only update the state table.
901        let degree_state = self.degree_state.as_mut().expect("degree table missing");
902        let (row, degree) = value.to_table_rows(
903            &self.state.order_key_indices,
904            degree_state.degree_inequality_idx,
905        );
906        self.state.table.delete(row);
907        degree_state.table.delete(degree);
908        Ok(())
909    }
910
911    /// Delete a row
912    /// Used when the side does not need to update degree.
913    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
914        self.delete_row_in_mem(key, &value)?;
915
916        // If no cache maintained, only update the state table.
917        self.state.table.delete(value);
918        Ok(())
919    }
920
    /// Update a [`JoinEntryState`] into the hash table.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict the cache.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Cached entry count for this hash table.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    /// Bitmap of join-key columns that may match NULLs.
    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    /// Id of the underlying data state table.
    pub fn table_id(&self) -> TableId {
        self.state.table.table_id()
    }

    /// Data types of the join key columns.
    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Return true if the inequality key is null.
    /// # Panics
    /// Panics if the inequality key is not set.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self
            .inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing");
        row.datum_at(desc.idx).is_none()
    }

    /// Serialize the inequality key from a row.
    /// # Panics
    /// Panics if the inequality key is not set.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing")
            .serialize_inequal_key_from_row(&row)
    }

    /// Memcmp-serialize the primary key columns of `row`.
    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
973}
974
/// Post-commit handle produced by `flush`: holds the post-commit stages of
/// the data table and (optionally) the degree table, plus mutable access to
/// the in-memory cache. Must be consumed to finish the commit.
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Post-commit stage of the data state table.
    state: StateTablePostCommit<'a, S>,
    /// Post-commit stage of the degree table, when this side maintains one.
    degree_state: Option<StateTablePostCommit<'a, S>>,
    /// In-memory join cache of the owning hash map.
    inner: &'a mut JoinHashMapInner<K, E>,
}
981
982use risingwave_common::catalog::TableId;
983use risingwave_common_estimate_size::KvSize;
984use thiserror::Error;
985
986use super::*;
987use crate::executor::prelude::{Stream, try_stream};
988
/// We manage a `HashMap` in memory for all entries belonging to a join key.
/// When evicted, `cached` does not hold any entries.
///
/// If a `JoinEntryState` exists for a join key, all records under this
/// join key will be present in the cache.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The full copy of the state.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Index used for AS OF join. The key is the inequality column value. The value is the set of
    /// primary keys in `cached` that share this inequality key.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    /// Estimated heap usage of the tracked keys/values (not container overhead).
    kv_heap_size: KvSize,
}
1002
impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        // TODO: Add btreemap internal size.
        // https://github.com/risingwavelabs/risingwave/issues/9713
        // Only the tracked key/value bytes are counted; container overhead is not.
        self.kv_heap_size.size()
    }
}
1010
/// Consistency errors for operations on a [`JoinEntryState`] cache.
/// Under non-strict consistency, the corresponding conditions are logged
/// and tolerated instead of being surfaced as errors.
#[derive(Error, Debug)]
pub enum JoinEntryError {
    /// Inserting a pk that already exists in the cache.
    #[error("double inserting a join state entry")]
    Occupied,
    /// Removing a pk that is not present in the cache.
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    /// The inequality index referenced a pk that the cache does not hold.
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}
1020
1021impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert into the cache.
    ///
    /// Returns a mutable reference to the newly inserted value, or
    /// [`JoinEntryError::Occupied`] if `key` already exists — only reachable
    /// under strict consistency, since otherwise any existing entry is
    /// evicted first and the duplicate is merely logged.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // strict consistency is off, let's remove existing (if any) first
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        // NOTE: size is accounted before `try_insert`; under strict consistency
        // a failed (duplicate) insert leaves the accounted size inflated.
        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // if not silent, we should log the error
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }
1058
1059    /// Delete from the cache.
1060    pub fn remove(
1061        &mut self,
1062        pk: PkType,
1063        inequality_key: Option<&InequalKeyType>,
1064    ) -> Result<(), JoinEntryError> {
1065        if let Some(value) = self.cached.remove(&pk) {
1066            self.kv_heap_size.sub(&pk, &value);
1067            if let Some(inequality_key) = inequality_key {
1068                self.remove_pk_from_inequality_index(&pk, inequality_key);
1069            }
1070            Ok(())
1071        } else if enable_strict_consistency() {
1072            Err(JoinEntryError::Remove)
1073        } else {
1074            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
1075            Ok(())
1076        }
1077    }
1078
1079    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
1080        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
1081            if pk_set.remove(pk).is_none() {
1082                if enable_strict_consistency() {
1083                    panic!("removing a pk that it not in the inequality index");
1084                } else {
1085                    consistency_error!(?pk, "removing a pk that it not in the inequality index");
1086                };
1087            } else {
1088                self.kv_heap_size.sub(pk, &());
1089            }
1090            if pk_set.is_empty() {
1091                self.inequality_index.remove(inequality_key);
1092            }
1093        }
1094    }
1095
1096    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
1097        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
1098            let pk_size = pk.estimated_size();
1099            if pk_set.try_insert(pk, ()).is_err() {
1100                if enable_strict_consistency() {
1101                    panic!("inserting a pk that it already in the inequality index");
1102                } else {
1103                    consistency_error!("inserting a pk that it already in the inequality index");
1104                };
1105            } else {
1106                self.kv_heap_size.add_size(pk_size);
1107            }
1108        } else {
1109            let mut pk_set = JoinRowSet::default();
1110            pk_set.try_insert(pk, ()).expect("pk set should be empty");
1111            self.inequality_index
1112                .try_insert(inequality_key, pk_set)
1113                .expect("pk set should be empty");
1114        }
1115    }
1116
1117    pub fn get(
1118        &self,
1119        pk: &PkType,
1120        data_types: &[DataType],
1121    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1122        self.cached
1123            .get(pk)
1124            .map(|encoded| encoded.decode(data_types))
1125    }
1126
    /// Note: the first item in the tuple is the mutable reference to the value in this entry, while
    /// the second item is the decoded value. To mutate the degree, one **must not** forget to apply
    /// the changes to the first item.
    ///
    /// WARNING: Should not change the heap size of `StateValueType` with the mutable reference.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<E::DecodedRow>>,
        ),
    > + 'a {
        // Decode per item so the caller gets the raw (mutable) encoded row and
        // its decoded form together; decode errors are surfaced per element.
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }
1146
    /// Number of rows currently cached in this entry.
    pub fn len(&self) -> usize {
        self.cached.len()
    }
1150
1151    /// Range scan the cache using the inequality index.
1152    pub fn range_by_inequality<'a, R>(
1153        &'a self,
1154        range: R,
1155        data_types: &'a [DataType],
1156    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<E::DecodedRow>>> + 'a
1157    where
1158        R: RangeBounds<InequalKeyType> + 'a,
1159    {
1160        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
1161            pk_set
1162                .keys()
1163                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
1164        })
1165    }
1166
1167    /// Get the records whose inequality key upper bound satisfy the given bound.
1168    pub fn upper_bound_by_inequality<'a>(
1169        &'a self,
1170        bound: Bound<&InequalKeyType>,
1171        data_types: &'a [DataType],
1172    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1173        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
1174            if let Some(pk) = pk_set.first_key_sorted() {
1175                self.get_by_indexed_pk(pk, data_types)
1176            } else {
1177                panic!("pk set for a index record must has at least one element");
1178            }
1179        } else {
1180            None
1181        }
1182    }
1183
1184    pub fn get_by_indexed_pk(
1185        &self,
1186        pk: &PkType,
1187        data_types: &[DataType],
1188    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>>
1189where {
1190        if let Some(value) = self.cached.get(pk) {
1191            Some(value.decode(data_types))
1192        } else if enable_strict_consistency() {
1193            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
1194        } else {
1195            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
1196            None
1197        }
1198    }
1199
1200    /// Get the records whose inequality key lower bound satisfy the given bound.
1201    pub fn lower_bound_by_inequality<'a>(
1202        &'a self,
1203        bound: Bound<&InequalKeyType>,
1204        data_types: &'a [DataType],
1205    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1206        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
1207            if let Some(pk) = pk_set.first_key_sorted() {
1208                self.get_by_indexed_pk(pk, data_types)
1209            } else {
1210                panic!("pk set for a index record must has at least one element");
1211            }
1212        } else {
1213            None
1214        }
1215    }
1216
1217    pub fn get_first_by_inequality<'a>(
1218        &'a self,
1219        inequality_key: &InequalKeyType,
1220        data_types: &'a [DataType],
1221    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1222        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
1223            if let Some(pk) = pk_set.first_key_sorted() {
1224                self.get_by_indexed_pk(pk, data_types)
1225            } else {
1226                panic!("pk set for a index record must has at least one element");
1227            }
1228        } else {
1229            None
1230        }
1231    }
1232
    /// Read-only access to the inequality-key → pk-set index.
    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
1236}
1237
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    /// Insert every row of `data_chunk` into `managed_state`, serializing the
    /// pk (and optionally one inequality column) with memcmp encoding.
    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            // Pk is only a `i64` here, so encoding method does not matter.
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    /// Assert that iterating `managed_state` (in pk order) yields exactly the
    /// rows `(col1[i], col2[i])`, each with degree 0.
    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    // Inserts two chunks and verifies the cache keeps rows sorted by pk.
    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        // `Vec` in state
        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    // Exercises the inequality index: upper/lower bound lookups on column 1.
    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        // `Vec` in state
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}