risingwave_stream/executor/join/hash_join.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::alloc::Global;
use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;

use anyhow::{Context, anyhow};
use futures::future::{join, try_join};
use futures::{StreamExt, pin_mut, stream};
use futures_async_stream::for_await;
use join_row_set::JoinRowSet;
use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt, once};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;
use thiserror_ext::AsReport;

use super::row::{CachedJoinRow, DegreeType};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::{JoinEncoding, StreamExecutorError};
use crate::task::{ActorId, AtomicU64Ref, FragmentId};

/// Memcomparable encoding.
type PkType = Vec<u8>;
type InequalKeyType = Vec<u8>;

pub type HashValueType<E> = Box<JoinEntryState<E>>;

impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
    fn estimated_heap_size(&self) -> usize {
        self.as_ref().estimated_heap_size()
    }
}

/// The wrapper for [`JoinEntryState`] which should be `Some` most of the time in the hash table.
///
/// When the executor is operating on a specific entry of the map, it can take ownership of
/// the entry by moving the value out of the `Option`, instead of holding a mutable reference to
/// the map, which keeps the borrow checker happy.
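///
/// An illustrative sketch of the take-then-put-back pattern (using [`JoinHashMap::take_state`]
/// and [`JoinHashMap::update_state`] defined below):
///
/// ```ignore
/// let mut state = hash_map.take_state(&key).await?; // the wrapper's inner `Option` becomes `None`
/// // ... probe `state` and update degrees without borrowing the map ...
/// hash_map.update_state(&key, state); // the wrapper holds `Some` again
/// ```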
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);

pub(crate) enum CacheResult<E: JoinEncoding> {
    NeverMatch,            // Will never match, will not be in cache at all.
    Miss,                  // Cache-miss
    Hit(HashValueType<E>), // Cache-hit
}

impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        self.0.estimated_heap_size()
    }
}

impl<E: JoinEncoding> HashValueWrapper<E> {
    const MESSAGE: &'static str = "the state should always be `Some`";

    /// Take the value out of the wrapper. Panic if the value is `None`.
    pub fn take(&mut self) -> HashValueType<E> {
        self.0.take().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
    type Target = HashValueType<E>;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut().expect(Self::MESSAGE)
    }
}

type JoinHashMapInner<K, E> =
    ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;

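/// Per-executor counters for the join cache. They are accumulated locally and only
/// folded into the guarded Prometheus counters in [`JoinHashMapMetrics::flush`], so the
/// shared counters are not touched on every row.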
pub struct JoinHashMapMetrics {
    /// Number of lookups that missed the cache.
    lookup_miss_count: usize,
    /// Total number of cache lookups.
    total_lookup_count: usize,
    /// Number of cache misses when inserting a row.
    insert_cache_miss_count: usize,

    // Metrics
    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metric: LabelGuardedIntCounter,
}

impl JoinHashMapMetrics {
    pub fn new(
        metrics: &StreamingMetrics,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
        join_table_id: u32,
    ) -> Self {
        let actor_id = actor_id.to_string();
        let fragment_id = fragment_id.to_string();
        let join_table_id = join_table_id.to_string();
        let join_lookup_total_count_metric = metrics
            .join_lookup_total_count
            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
        let join_lookup_miss_count_metric = metrics
            .join_lookup_miss_count
            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
        let join_insert_cache_miss_count_metric = metrics
            .join_insert_cache_miss_count
            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);

        Self {
            lookup_miss_count: 0,
            total_lookup_count: 0,
            insert_cache_miss_count: 0,
            join_lookup_total_count_metric,
            join_lookup_miss_count_metric,
            join_insert_cache_miss_count_metric,
        }
    }

    pub fn flush(&mut self) {
        self.join_lookup_total_count_metric
            .inc_by(self.total_lookup_count as u64);
        self.join_lookup_miss_count_metric
            .inc_by(self.lookup_miss_count as u64);
        self.join_insert_cache_miss_count_metric
            .inc_by(self.insert_cache_miss_count as u64);
        self.total_lookup_count = 0;
        self.lookup_miss_count = 0;
        self.insert_cache_miss_count = 0;
    }
}

/// Inequality key description for `AsOf` join.
struct InequalityKeyDesc {
    idx: usize,
    serializer: OrderedRowSerde,
}

impl InequalityKeyDesc {
    /// Serialize the inequality key from a row.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        let inequality_key = row.project(std::slice::from_ref(&self.idx));
        inequality_key.memcmp_serialize(&self.serializer)
    }
}

pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Store the join states.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns
    join_key_data_types: Vec<DataType>,
    /// Null safe bitmap for each join pair
    null_matched: K::Bitmap,
    /// The memcomparable serializer of primary key.
    pk_serializer: OrderedRowSerde,
    /// State table. Contains the data from upstream.
    state: TableInner<S>,
    /// Degree table.
    ///
    /// The degree is generated by the hash join executor.
    /// Each row in `state` has a corresponding degree in `degree state`.
    /// A degree value `d` for a row means the row has `d` matched rows on the other join side.
    ///
    /// It is only used by the sides that need it:
    ///
    /// - Full Outer: both sides
    /// - Left Outer/Semi/Anti: left side
    /// - Right Outer/Semi/Anti: right side
    /// - Inner: neither side.
    ///
    /// Should be set to `None` if `need_degree_table` was set to `false`.
    ///
    /// The degree of each row tells us whether we need to emit `NULL` for the row.
    /// For instance, given `lhs LEFT JOIN rhs`:
    /// if the degree of a row in `lhs` is 0, the row has no match in `rhs`;
    /// if the degree is 2, it has two matches in `rhs`.
    /// When emitting the join result, we need to emit `NULL` for an `lhs` row
    /// exactly when its degree is 0.
    ///
    /// Why not just use a boolean value instead of a degree count?
    /// Consider the case where we delete a matched record from `rhs`.
    /// Since the record could be deleted,
    /// there must have been a record in `rhs` matching the record in `lhs`,
    /// so the boolean would be `true`.
    /// But after removing this record we would not know how many matches remain,
    /// since a boolean does not carry the count.
    /// Hence we need to store the count of matched records.
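    ///
    /// A worked example (illustrative): for `lhs LEFT JOIN rhs`, suppose an `lhs` row has
    /// degree 2. Deleting one matching `rhs` row decrements the degree to 1, so the
    /// `NULL`-padded output stays suppressed; deleting the second brings it to 0, and only
    /// then must the `NULL`-padded row be emitted. A boolean flag could not distinguish
    /// these two deletions.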
    degree_state: Option<TableInner<S>>,
    // TODO(kwannoel): Make this `const` instead.
    /// Whether the degree table is needed.
    need_degree_table: bool,
    /// Whether the pk is contained in the join key.
    pk_contained_in_jk: bool,
    /// Inequality key description for `AsOf` join.
    inequality_key_desc: Option<InequalityKeyDesc>,
    /// Metrics of the hash map
    metrics: JoinHashMapMetrics,
    _marker: std::marker::PhantomData<E>,
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// NOTE(kwannoel): This allows us to concurrently stream records from the `state_table`,
    /// and update the degree table, without using `unsafe` code.
    ///
    /// This is because we obtain separate references to separate parts of the `JoinHashMap`,
    /// instead of reusing the same reference to `JoinHashMap` for concurrent read access to `state_table`,
    /// and write access to the degree table.
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}

#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

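    // Note: `fetch_degrees` scans the degree table with the same join-key prefix and the
    // same order key as this scan of the state table, so `degrees[i]` (when present)
    // corresponds to the i-th row yielded below.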
    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        let join_row = JoinRow::new(
            encoded_row.into_owned_row(),
            degrees.as_ref().map_or(0, |d| d[i]),
        );
        yield (encoded_pk, join_row);
    }
}

/// We use this to fetch ALL degrees into memory, rather than a streaming interface.
/// It is necessary because we must update the `degree_state_table` concurrently.
/// If we obtained the degrees in a stream,
/// we would need to hold an immutable reference to the state table for the entire lifetime,
/// preventing us from concurrently updating the state table.
///
/// The cost of fetching all degrees upfront is acceptable. We already do so
/// in `fetch_cached_state`.
/// The memory use should be limited since we only store a `u64` per row.
///
/// Say we have an amplification of 1B rows; that is 1B * 8 bytes ~= 8GB.
///
/// As a further optimization, we could break the streaming update
/// to flush the in-memory degrees, if this proves to have high memory consumption.
///
/// TODO(kwannoel): Perhaps we can cache these separately from matched rows too,
/// because matched rows may occupy a larger capacity.
///
/// Argument for this:
/// We only hit this on a cache miss, where it is a one-off cost.
/// Keeping degrees cached separately from matched rows is beneficial:
/// we can then evict matched rows without touching the degrees.
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
        .await?;
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        let degree_i64 = degree_row
            .datum_at(degree_row.len() - 1)
            .expect("degree should not be NULL");
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}

// NOTE(kwannoel): This is not really specific to `TableInner`:
// a degree table is a `TableInner`, but a `TableInner` might not be a degree table.
// Hence we don't specify it in its impl block.
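// Illustrative shape (assuming the order key layout documented on `TableInner`): the
// degree row is `(order_key..., degree)`, so an increment rewrites
// `(jk..., pk..., d)` to `(jk..., pk..., d + 1)` via a state-table update.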
pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
    order_key_indices: &[usize],
    degree_state: &mut TableInner<S>,
    matched_row: &mut JoinRow<OwnedRow>,
) {
    let old_degree_row = matched_row
        .row
        .as_ref()
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    if INCREMENT {
        matched_row.degree += 1;
    } else {
        // DECREMENT
        matched_row.degree -= 1;
    }
    let new_degree_row = matched_row
        .row
        .as_ref()
        .project(order_key_indices)
        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
    degree_state.table.update(old_degree_row, new_degree_row);
}

pub struct TableInner<S: StateStore> {
    /// Indices of the (cache) pk in a state row
    pk_indices: Vec<usize>,
    /// Indices of the join key in a state row
    join_key_indices: Vec<usize>,
    /// The order key of the join side has the following format:
    /// `| join_key ... | pk ... |`,
    /// where the `join_key` part contains the join key columns and the `pk` part the
    /// remaining pk columns. The order key is thus a superset of the pk.
    order_key_indices: Vec<usize>,
    pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
        let order_key_indices = table.pk_indices().to_vec();
        Self {
            pk_indices,
            join_key_indices,
            order_key_indices,
            table,
        }
    }

    fn error_context(&self, row: &impl Row) -> String {
        let pk = row.project(&self.pk_indices);
        let jk = row.project(&self.join_key_indices);
        format!(
            "join key: {}, pk: {}, row: {}, state_table_id: {}",
            jk.display(),
            pk.display(),
            row.display(),
            self.table.table_id()
        )
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] backed by an unbounded LRU cache, with eviction driven by
    /// the global watermark sequence.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        inequality_key_idx: Option<usize>,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        let alloc = StatsAlloc::new(Global).shared();
        // TODO: unify pk encoding with state table.
        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        let inequality_key_desc = inequality_key_idx.map(|idx| {
            let serializer = OrderedRowSerde::new(
                vec![state_all_data_types[idx].clone()],
                vec![OrderType::ascending()],
            );
            InequalityKeyDesc { idx, serializer }
        });

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            table: state_table,
        };

        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        let cache = ManagedLruCache::unbounded_with_hasher_in(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
            alloc,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            inequality_key_desc,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}

impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        if cache_may_stale.unwrap_or(false) {
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        // TODO: remove data in cache.
        self.state.table.update_watermark(watermark.clone());
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.update_watermark(watermark);
        }
    }

    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
    /// `update_state` after some operations to put the state back.
    ///
    /// Unlike [`JoinHashMap::take_state`], this never fetches from remote storage: on a cache
    /// miss it simply returns [`CacheResult::Miss`].
    ///
    /// Note: This will NOT remove anything from remote storage.
    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
        self.metrics.total_lookup_count += 1;
        if self.inner.contains(key) {
            tracing::trace!("hit cache for join key: {:?}", key);
            // Do not update the LRU statistics here with `peek_mut` since we will put the state
            // back.
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            CacheResult::Hit(state.take())
        } else {
            tracing::trace!("miss cache for join key: {:?}", key);
            CacheResult::Miss
        }
    }

    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
    /// `update_state` after some operations to put the state back.
    ///
    /// If the state does not exist in the cache, fetch it from remote storage and return it. If it
    /// still does not exist in remote storage, a [`JoinEntryState`] with an empty cache will be
    /// returned.
    ///
    /// Note: This will NOT remove anything from remote storage.
    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
        self.metrics.total_lookup_count += 1;
        let state = if self.inner.contains(key) {
            // Do not update the LRU statistics here with `peek_mut` since we will put the state
            // back.
            let mut state = self.inner.peek_mut(key).expect("checked contains");
            state.take()
        } else {
            self.metrics.lookup_miss_count += 1;
            self.fetch_cached_state(key).await?.into()
        };
        Ok(state)
    }

    /// Fetch cache from the state store. Should only be called if the key does not exist in memory.
    /// Will return an empty `JoinEntryState` even when the state does not exist in remote storage.
    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
        let key = key.deserialize(&self.join_key_data_types)?;

        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();

        if self.need_degree_table {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );
            let degree_state = self.degree_state.as_ref().unwrap();
            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
                &key,
                sub_range,
                PrefetchOptions::default(),
            );

            let (table_iter, degree_table_iter) =
                try_join(table_iter_fut, degree_table_iter_fut).await?;

            let mut pinned_table_iter = std::pin::pin!(table_iter);
            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);

            // To better tolerate an inconsistent stream, we first buffer all rows and
            // degree rows, check their counts, and only then iterate over them.
            let mut rows = vec![];
            let mut degree_rows = vec![];
            let mut inconsistency_happened = false;
            loop {
                let (row, degree_row) =
                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
                let (row, degree_row) = match (row, degree_row) {
                    (None, None) => break,
                    (None, Some(_)) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(_), None) => {
                        inconsistency_happened = true;
                        consistency_panic!(
                            "mismatched row and degree table of join key: {:?}, input table has more rows",
                            &key
                        );
                        break;
                    }
                    (Some(r), Some(d)) => (r, d),
                };

                let row = row?;
                let degree_row = degree_row?;
                rows.push(row);
                degree_rows.push(degree_row);
            }

            if inconsistency_happened {
                // Pk-based row-degree pairing.
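                // Both buffered vectors are sorted by pk (keyed rows from prefix scans over
                // tables sharing the same order key), so we can merge-join them on pk and
                // keep only the pairs present on both sides, skipping orphan rows.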
                assert_ne!(rows.len(), degree_rows.len());

                let row_iter = stream::iter(rows.into_iter()).peekable();
                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
                pin_mut!(row_iter);
                pin_mut!(degree_row_iter);

                loop {
                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
                        (None, _) | (_, None) => break,
                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
                            Ordering::Greater => {
                                degree_row_iter.next().await;
                            }
                            Ordering::Less => {
                                row_iter.next().await;
                            }
                            Ordering::Equal => {
                                let row =
                                    row_iter.next().await.expect("we matched some(row) above");
                                let degree_row = degree_row_iter
                                    .next()
                                    .await
                                    .expect("we matched some(degree_row) above");
                                let pk = row
                                    .as_ref()
                                    .project(&self.state.pk_indices)
                                    .memcmp_serialize(&self.pk_serializer);
                                let degree_i64 = degree_row
                                    .datum_at(degree_row.len() - 1)
                                    .expect("degree should not be NULL");
                                let inequality_key = self
                                    .inequality_key_desc
                                    .as_ref()
                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                                entry_state
                                    .insert(
                                        pk,
                                        E::encode(&JoinRow::new(
                                            row.row(),
                                            degree_i64.into_int64() as u64,
                                        )),
                                        inequality_key,
                                    )
                                    .with_context(|| self.state.error_context(row.row()))?;
                            }
                        },
                    }
                }
            } else {
                // 1 to 1 row-degree pairing.
                // Actually it's possible that both the input data table and the degree table missed
                // some equal number of rows, but let's ignore this case because it should be rare.

                assert_eq!(rows.len(), degree_rows.len());

                #[for_await]
                for (row, degree_row) in
                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
                {
                    let pk1 = row.key();
                    let pk2 = degree_row.key();
                    debug_assert_eq!(
                        pk1, pk2,
                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
                    );
                    let pk = row
                        .as_ref()
                        .project(&self.state.pk_indices)
                        .memcmp_serialize(&self.pk_serializer);
                    let inequality_key = self
                        .inequality_key_desc
                        .as_ref()
                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                    let degree_i64 = degree_row
                        .datum_at(degree_row.len() - 1)
                        .expect("degree should not be NULL");
                    entry_state
                        .insert(
                            pk,
                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
                            inequality_key,
                        )
                        .with_context(|| self.state.error_context(row.row()))?;
                }
            }
        } else {
            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
                &(Bound::Unbounded, Bound::Unbounded);
            let table_iter = self
                .state
                .table
                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
                .await?;

            #[for_await]
            for entry in table_iter {
                let row = entry?;
                let pk = row
                    .as_ref()
                    .project(&self.state.pk_indices)
                    .memcmp_serialize(&self.pk_serializer);
                let inequality_key = self
                    .inequality_key_desc
                    .as_ref()
                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
                entry_state
                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
                    .with_context(|| self.state.error_context(row.row()))?;
            }
        };

        Ok(entry_state)
    }

    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
        self.metrics.flush();
        let state_post_commit = self.state.table.commit(epoch).await?;
        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
            Some(degree_state.table.commit(epoch).await?)
        } else {
            None
        };
        Ok(JoinHashMapPostCommit {
            state: state_post_commit,
            degree_state: degree_state_post_commit,
            inner: &mut self.inner,
        })
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state.table.try_flush().await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.try_flush().await?;
        }
        Ok(())
    }

    pub fn insert_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.insert(key, value)
        } else {
            self.insert_row(key, value.row)
        }
    }

    /// Insert a join row
    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        let pk = self.serialize_pk_from_row(&value.row);

        let inequality_key = self
            .inequality_key_desc
            .as_ref()
            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));

        // TODO(yuhao): avoid this `contains`.
        // https://github.com/risingwavelabs/risingwave/issues/9233
        if self.inner.contains(key) {
            // Update cache
            let mut entry = self.inner.get_mut(key).expect("checked contains");
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
        } else if self.pk_contained_in_jk {
            // Refill the cache when the join key exists in neither the cache nor storage.
            self.metrics.insert_cache_miss_count += 1;
            let mut entry: JoinEntryState<E> = JoinEntryState::default();
            entry
                .insert(pk, E::encode(&value), inequality_key)
                .with_context(|| self.state.error_context(&value.row))?;
            self.update_state(key, entry.into());
        }

        // Update the flush buffer.
        if let Some(degree_state) = self.degree_state.as_mut() {
            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
            self.state.table.insert(row);
            degree_state.table.insert(degree);
        } else {
            self.state.table.insert(value.row);
        }
        Ok(())
    }

    /// Insert a row.
    /// Used when the side does not need to update degree.
    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        let join_row = JoinRow::new(&value, 0);
        self.insert(key, join_row)?;
        Ok(())
    }

    pub fn delete_handle_degree(
        &mut self,
        key: &K,
        value: JoinRow<impl Row>,
    ) -> StreamExecutorResult<()> {
        if self.need_degree_table {
            self.delete(key, value)
        } else {
            self.delete_row(key, value.row)
        }
    }

    /// Delete a join row
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value.row)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);
            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value.row))?;
        }

        // Whether or not a cache entry is maintained, always update the state and degree tables.
        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
        self.state.table.delete(row);
        let degree_state = self.degree_state.as_mut().expect("degree table missing");
        degree_state.table.delete(degree);
        Ok(())
    }

    /// Delete a row.
    /// Used when the side does not need to update degree.
    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
        if let Some(mut entry) = self.inner.get_mut(key) {
            let pk = (&value)
                .project(&self.state.pk_indices)
                .memcmp_serialize(&self.pk_serializer);

            let inequality_key = self
                .inequality_key_desc
                .as_ref()
                .map(|desc| desc.serialize_inequal_key_from_row(&value));
            entry
                .remove(pk, inequality_key.as_ref())
                .with_context(|| self.state.error_context(&value))?;
        }

        // Whether or not a cache entry is maintained, always update the state table.
        self.state.table.delete(value);
        Ok(())
    }

    /// Put a [`JoinEntryState`] (back) into the hash table.
    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
    }

    /// Evict the cache.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Cached entry count for this hash table.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    pub fn table_id(&self) -> u32 {
        self.state.table.table_id()
    }

    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Return true if the inequality key is null.
    /// # Panics
    /// Panics if the inequality key is not set.
    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
        let desc = self
            .inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing");
        row.datum_at(desc.idx).is_none()
    }

    /// Serialize the inequality key from a row.
    /// # Panics
    /// Panics if the inequality key is not set.
    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
        self.inequality_key_desc
            .as_ref()
            .expect("inequality key desc missing")
            .serialize_inequal_key_from_row(&row)
    }

    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
}

#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    inner: &'a mut JoinHashMapInner<K, E>,
}

use risingwave_common_estimate_size::KvSize;
use thiserror::Error;

use super::*;
use crate::executor::prelude::{Stream, try_stream};

/// We manage a `HashMap` in memory for all entries belonging to a join key.
/// When evicted, `cached` does not hold any entries.
///
/// If a `JoinEntryState` exists for a join key, then all records under this
/// join key will be present in the cache.
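///
/// Layout sketch (illustrative): if rows with pks `p1` and `p2` share the inequality key
/// `k`, then `cached = {p1 -> row1, p2 -> row2}` and `inequality_index = {k -> {p1, p2}}`.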
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The full copy of the state.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Index used for `AsOf` join. The key is the inequality column value; the value is the
    /// set of primary keys in `cached`.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    kv_heap_size: KvSize,
}

impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        // TODO: Add btreemap internal size.
        // https://github.com/risingwavelabs/risingwave/issues/9713
        self.kv_heap_size.size()
    }
}

#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}

impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert into the cache.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
        inequality_key: Option<InequalKeyType>,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // Strict consistency is off, so remove the existing entry (if any) first.
            if let Some(old_value) = self.cached.remove(&key) {
                if let Some(inequality_key) = inequality_key.as_ref() {
                    self.remove_pk_from_inequality_index(&key, inequality_key);
                }
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        self.kv_heap_size.add(&key, &value);

        if let Some(inequality_key) = inequality_key {
            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
        }
        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // If not silent, we should log the error.
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }

    /// Delete from the cache.
    pub fn remove(
        &mut self,
        pk: PkType,
        inequality_key: Option<&InequalKeyType>,
    ) -> Result<(), JoinEntryError> {
        if let Some(value) = self.cached.remove(&pk) {
            self.kv_heap_size.sub(&pk, &value);
            if let Some(inequality_key) = inequality_key {
                self.remove_pk_from_inequality_index(&pk, inequality_key);
            }
            Ok(())
        } else if enable_strict_consistency() {
            Err(JoinEntryError::Remove)
        } else {
            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
            Ok(())
        }
    }

    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
            if pk_set.remove(pk).is_none() {
                if enable_strict_consistency() {
                    panic!("removing a pk that is not in the inequality index");
                } else {
                    consistency_error!(?pk, "removing a pk that is not in the inequality index");
                };
            } else {
                self.kv_heap_size.sub(pk, &());
            }
            if pk_set.is_empty() {
                self.inequality_index.remove(inequality_key);
            }
        }
    }

    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
            let pk_size = pk.estimated_size();
            if pk_set.try_insert(pk, ()).is_err() {
                if enable_strict_consistency() {
                    panic!("inserting a pk that is already in the inequality index");
                } else {
                    consistency_error!("inserting a pk that is already in the inequality index");
                };
            } else {
                self.kv_heap_size.add_size(pk_size);
            }
        } else {
            let mut pk_set = JoinRowSet::default();
            pk_set.try_insert(pk, ()).expect("pk set should be empty");
            self.inequality_index
                .try_insert(inequality_key, pk_set)
                .expect("pk set should be empty");
        }
    }

    pub fn get(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        self.cached
            .get(pk)
            .map(|encoded| encoded.decode(data_types))
    }

    /// Note: the first item in the tuple is the mutable reference to the value in this entry, while
    /// the second item is the decoded value. To mutate the degree, one **must not** forget to apply
    /// the changes to the first item.
    ///
    /// WARNING: Should not change the heap size of the encoded row through the mutable reference.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<OwnedRow>>,
        ),
    > + 'a {
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }

    pub fn len(&self) -> usize {
        self.cached.len()
    }

    /// Range scan the cache using the inequality index.
    pub fn range_by_inequality<'a, R>(
        &'a self,
        range: R,
        data_types: &'a [DataType],
    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<OwnedRow>>> + 'a
    where
        R: RangeBounds<InequalKeyType> + 'a,
    {
        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
            pk_set
                .keys()
                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
        })
    }

    /// Get the record with the largest inequality key satisfying the given upper bound.
    pub fn upper_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn get_by_indexed_pk(
        &self,
        pk: &PkType,
        data_types: &[DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some(value) = self.cached.get(pk) {
            Some(value.decode(data_types))
        } else if enable_strict_consistency() {
            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
        } else {
            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
            None
        }
    }

    /// Get the record with the smallest inequality key satisfying the given lower bound.
    pub fn lower_bound_by_inequality<'a>(
        &'a self,
        bound: Bound<&InequalKeyType>,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn get_first_by_inequality<'a>(
        &'a self,
        inequality_key: &InequalKeyType,
        data_types: &'a [DataType],
    ) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
            if let Some(pk) = pk_set.first_key_sorted() {
                self.get_by_indexed_pk(pk, data_types)
            } else {
                panic!("pk set for an index record must have at least one element");
            }
        } else {
            None
        }
    }

    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
        &self.inequality_index
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            // The pk is only an `i64` here, so the encoding method does not matter.
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row[0], Some(ScalarImpl::Int64(*d1)));
            assert_eq!(matched_row.row[1], Some(ScalarImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        // `Vec` in state
        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        // `Vec` in state
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}