risingwave_stream/executor/join/
hash_join.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::ops::{Bound, Deref, DerefMut, RangeBounds};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use futures::future::{join, try_join};
21use futures::{StreamExt, pin_mut, stream};
22use futures_async_stream::for_await;
23use join_row_set::JoinRowSet;
24use risingwave_common::bitmap::Bitmap;
25use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
26use risingwave_common::metrics::LabelGuardedIntCounter;
27use risingwave_common::row::{OwnedRow, Row, RowExt, once};
28use risingwave_common::types::{DataType, ScalarImpl};
29use risingwave_common::util::epoch::EpochPair;
30use risingwave_common::util::iter_util::ZipEqFast;
31use risingwave_common::util::row_serde::OrderedRowSerde;
32use risingwave_common::util::sort_util::OrderType;
33use risingwave_common_estimate_size::EstimateSize;
34use risingwave_storage::StateStore;
35use risingwave_storage::store::PrefetchOptions;
36use risingwave_storage::table::KeyedRow;
37use thiserror_ext::AsReport;
38
39use super::row::{CachedJoinRow, DegreeType};
40use crate::cache::ManagedLruCache;
41use crate::common::metrics::MetricsInfo;
42use crate::common::table::state_table::{StateTable, StateTablePostCommit};
43use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
44use crate::executor::error::StreamExecutorResult;
45use crate::executor::join::row::JoinRow;
46use crate::executor::monitor::StreamingMetrics;
47use crate::executor::{JoinEncoding, StreamExecutorError};
48use crate::task::{ActorId, AtomicU64Ref, FragmentId};
49
/// Primary key of a join row, in memcomparable encoding.
type PkType = Vec<u8>;
/// Inequality key of a join row (used by `AsOf` join), in memcomparable encoding.
type InequalKeyType = Vec<u8>;

/// Value stored per join key in the hash table: a boxed [`JoinEntryState`].
pub type HashValueType<E> = Box<JoinEntryState<E>>;
55
56impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
57    fn estimated_heap_size(&self) -> usize {
58        self.as_ref().estimated_heap_size()
59    }
60}
61
/// The wrapper for [`JoinEntryState`] which should be `Some` most of the time in the hash table.
///
/// When the executor is operating on a specific entry of the map, it can take ownership of the
/// entry by moving the value out of the `Option` instead of holding a mutable reference into the
/// map, which keeps the borrow checker happy.
///
/// Dereferencing (or calling `take` on) a wrapper whose value has been taken out panics.
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
68
/// Outcome of looking up a join key in the in-memory cache.
pub(crate) enum CacheResult<E: JoinEncoding> {
    /// Will never match, will not be in cache at all.
    NeverMatch,
    /// Cache-miss
    Miss,
    /// Cache-hit; carries the state taken out of the cache.
    Hit(HashValueType<E>),
}
74
75impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
76    fn estimated_heap_size(&self) -> usize {
77        self.0.estimated_heap_size()
78    }
79}
80
81impl<E: JoinEncoding> HashValueWrapper<E> {
82    const MESSAGE: &'static str = "the state should always be `Some`";
83
84    /// Take the value out of the wrapper. Panic if the value is `None`.
85    pub fn take(&mut self) -> HashValueType<E> {
86        self.0.take().expect(Self::MESSAGE)
87    }
88}
89
90impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
91    type Target = HashValueType<E>;
92
93    fn deref(&self) -> &Self::Target {
94        self.0.as_ref().expect(Self::MESSAGE)
95    }
96}
97
98impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
99    fn deref_mut(&mut self) -> &mut Self::Target {
100        self.0.as_mut().expect(Self::MESSAGE)
101    }
102}
103
/// The in-memory cache of join states: a managed LRU cache from join key to the wrapped
/// [`JoinEntryState`], hashed with [`PrecomputedBuildHasher`].
type JoinHashMapInner<K, E> = ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher>;
105
/// Locally buffered cache counters of a [`JoinHashMap`], together with the metrics they are
/// flushed into.
pub struct JoinHashMapMetrics {
    // Basic information: local counts, added to the metrics below and reset by `flush`.
    /// How many lookups have missed the cache of the join executor
    lookup_miss_count: usize,
    /// How many lookups have been performed in total
    total_lookup_count: usize,
    /// How many times have we missed the cache when inserting a row
    insert_cache_miss_count: usize,

    // Metrics
    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}
119
120impl JoinHashMapMetrics {
121    pub fn new(
122        metrics: &StreamingMetrics,
123        actor_id: ActorId,
124        fragment_id: FragmentId,
125        side: &'static str,
126        join_table_id: TableId,
127    ) -> Self {
128        let actor_id = actor_id.to_string();
129        let fragment_id = fragment_id.to_string();
130        let join_table_id = join_table_id.to_string();
131        let join_lookup_total_count_metric = metrics
132            .join_lookup_total_count
133            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
134        let join_lookup_miss_count_metric = metrics
135            .join_lookup_miss_count
136            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
137        let join_insert_cache_miss_count_metrics = metrics
138            .join_insert_cache_miss_count
139            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
140
141        Self {
142            lookup_miss_count: 0,
143            total_lookup_count: 0,
144            insert_cache_miss_count: 0,
145            join_lookup_total_count_metric,
146            join_lookup_miss_count_metric,
147            join_insert_cache_miss_count_metrics,
148        }
149    }
150
151    pub fn flush(&mut self) {
152        self.join_lookup_total_count_metric
153            .inc_by(self.total_lookup_count as u64);
154        self.join_lookup_miss_count_metric
155            .inc_by(self.lookup_miss_count as u64);
156        self.join_insert_cache_miss_count_metrics
157            .inc_by(self.insert_cache_miss_count as u64);
158        self.total_lookup_count = 0;
159        self.lookup_miss_count = 0;
160        self.insert_cache_miss_count = 0;
161    }
162}
163
/// Inequality key description for `AsOf` join.
struct InequalityKeyDesc {
    /// Index of the inequality key column within a row.
    idx: usize,
    /// Memcomparable serializer for the single inequality key column.
    serializer: OrderedRowSerde,
}
169
170impl InequalityKeyDesc {
171    /// Serialize the inequality key from a row.
172    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
173        let indices = vec![self.idx];
174        let inequality_key = row.project(&indices);
175        inequality_key.memcmp_serialize(&self.serializer)
176    }
177}
178
/// The state of one side of a hash join: an in-memory cache of join entries backed by a state
/// table and, when needed, a degree table.
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Store the join states.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns
    join_key_data_types: Vec<DataType>,
    /// Null safe bitmap for each join pair
    null_matched: K::Bitmap,
    /// The memcomparable serializer of primary key.
    pk_serializer: OrderedRowSerde,
    /// State table. Contains the data from upstream.
    state: TableInner<S>,
    /// Degree table.
    ///
    /// The degree is generated from the hash join executor.
    /// Each row in `state` has a corresponding degree in `degree state`.
    /// A degree value `d` for a row means the row has `d` matched rows in the other join side.
    ///
    /// It will only be used when needed in a side.
    ///
    /// - Full Outer: both side
    /// - Left Outer/Semi/Anti: left side
    /// - Right Outer/Semi/Anti: right side
    /// - Inner: neither side.
    ///
    /// Should be set to `None` if `need_degree_table` was set to `false`.
    ///
    /// The degree of each row will tell us if we need to emit `NULL` for the row.
    /// For instance, given `lhs LEFT JOIN rhs`,
    /// If the degree of a row in `lhs` is 0, it means the row does not have a match in `rhs`.
    /// If the degree of a row in `lhs` is 2, it means the row has two matches in `rhs`.
    /// Now, when emitting the result of the join, we need to emit `NULL` for the row in `lhs` if
    /// the degree is 0.
    ///
    /// Why don't we just use a boolean value instead of a degree count?
    /// Consider the case where we delete a matched record from `rhs`.
    /// Since we can delete a record,
    /// there must have been a record in `rhs` that matched the record in `lhs`.
    /// So this value is `true`.
    /// But we don't know how many records are matched after removing this record,
    /// since we only stored a boolean value rather than the count.
    /// Hence we need to store the count of matched records.
    degree_state: Option<TableInner<S>>,
    // TODO(kwannoel): Make this `const` instead.
    /// Whether the degree table is needed
    need_degree_table: bool,
    /// Pk is part of the join key.
    pk_contained_in_jk: bool,
    /// Inequality key description for `AsOf` join.
    inequality_key_desc: Option<InequalityKeyDesc>,
    /// Metrics of the hash map
    metrics: JoinHashMapMetrics,
    /// Ties the join encoding `E` to this type.
    _marker: std::marker::PhantomData<E>,
}
232
233impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
234    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
235        (&self.state.order_key_indices, &mut self.degree_state)
236    }
237
238    /// NOTE(kwannoel): This allows us to concurrently stream records from the `state_table`,
239    /// and update the degree table, without using `unsafe` code.
240    ///
241    /// This is because we obtain separate references to separate parts of the `JoinHashMap`,
242    /// instead of reusing the same reference to `JoinHashMap` for concurrent read access to `state_table`,
243    /// and write access to the degree table.
244    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
245        &'a mut self,
246        key: &'a K,
247    ) -> StreamExecutorResult<(
248        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
249        &'a [usize],
250        &'a mut Option<TableInner<S>>,
251    )> {
252        let degree_state = &mut self.degree_state;
253        let (order_key_indices, pk_indices, state_table) = (
254            &self.state.order_key_indices,
255            &self.state.pk_indices,
256            &mut self.state.table,
257        );
258        let degrees = if let Some(degree_state) = degree_state {
259            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
260        } else {
261            None
262        };
263        let stream = into_stream(
264            &self.join_key_data_types,
265            pk_indices,
266            &self.pk_serializer,
267            state_table,
268            key,
269            degrees,
270        );
271        Ok((stream, order_key_indices, &mut self.degree_state))
272    }
273}
274
275#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
276pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
277    join_key_data_types: &'a [DataType],
278    pk_indices: &'a [usize],
279    pk_serializer: &'a OrderedRowSerde,
280    state_table: &'a StateTable<S>,
281    key: &'a K,
282    degrees: Option<Vec<DegreeType>>,
283) {
284    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
285    let decoded_key = key.deserialize(join_key_data_types)?;
286    let table_iter = state_table
287        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
288        .await?;
289
290    #[for_await]
291    for (i, entry) in table_iter.enumerate() {
292        let encoded_row = entry?;
293        let encoded_pk = encoded_row
294            .as_ref()
295            .project(pk_indices)
296            .memcmp_serialize(pk_serializer);
297        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
298        yield (encoded_pk, join_row);
299    }
300}
301
302/// We use this to fetch ALL degrees into memory.
303/// We use this instead of a streaming interface.
304/// It is necessary because we must update the `degree_state_table` concurrently.
305/// If we obtain the degrees in a stream,
306/// we will need to hold an immutable reference to the state table for the entire lifetime,
307/// preventing us from concurrently updating the state table.
308///
309/// The cost of fetching all degrees upfront is acceptable. We currently already do so
310/// in `fetch_cached_state`.
311/// The memory use should be limited since we only store a u64.
312///
313/// Let's say we have amplification of 1B, we will have 1B * 8 bytes ~= 8GB
314///
315/// We can also have further optimization, to permit breaking the streaming update,
316/// to flush the in-memory degrees, if this is proven to have high memory consumption.
317///
318/// TODO(kwannoel): Perhaps we can cache these separately from matched rows too.
319/// Because matched rows may occupy a larger capacity.
320///
321/// Argument for this:
322/// We only hit this when cache miss. When cache miss, we will have this as one off cost.
323/// Keeping this cached separately from matched rows is beneficial.
324/// Then we can evict matched rows, without touching the degrees.
325async fn fetch_degrees<K: HashKey, S: StateStore>(
326    key: &K,
327    join_key_data_types: &[DataType],
328    degree_state_table: &StateTable<S>,
329) -> StreamExecutorResult<Vec<DegreeType>> {
330    let key = key.deserialize(join_key_data_types)?;
331    let mut degrees = vec![];
332    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
333    let table_iter = degree_state_table
334        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
335        .await?;
336    #[for_await]
337    for entry in table_iter {
338        let degree_row = entry?;
339        let degree_i64 = degree_row
340            .datum_at(degree_row.len() - 1)
341            .expect("degree should not be NULL");
342        degrees.push(degree_i64.into_int64() as u64);
343    }
344    Ok(degrees)
345}
346
347// NOTE(kwannoel): This is not really specific to `TableInner`.
348// A degree table is `TableInner`, a `TableInner` might not be a degree table.
349// Hence we don't specify it in its impl block.
350pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
351    order_key_indices: &[usize],
352    degree_state: &mut TableInner<S>,
353    matched_row: &mut JoinRow<impl Row>,
354) {
355    let old_degree_row = (&matched_row.row)
356        .project(order_key_indices)
357        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
358    if INCREMENT {
359        matched_row.degree += 1;
360    } else {
361        // DECREMENT
362        matched_row.degree -= 1;
363    }
364    let new_degree_row = (&matched_row.row)
365        .project(order_key_indices)
366        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
367    degree_state.table.update(old_degree_row, new_degree_row);
368}
369
/// A state table together with the column indices needed to interpret its rows.
pub struct TableInner<S: StateStore> {
    /// Indices of the (cache) pk in a state row
    pk_indices: Vec<usize>,
    /// Indices of the join key in a state row
    join_key_indices: Vec<usize>,
    /// The order key of the join side has the following format:
    /// | `join_key` ... | pk ... |
    /// Where `join_key` contains all the columns not in the pk.
    /// It should be a superset of the pk.
    ///
    /// Taken from the pk indices of `table`.
    order_key_indices: Vec<usize>,
    /// The underlying state table.
    pub(crate) table: StateTable<S>,
}
382
383impl<S: StateStore> TableInner<S> {
384    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
385        let order_key_indices = table.pk_indices().to_vec();
386        Self {
387            pk_indices,
388            join_key_indices,
389            order_key_indices,
390            table,
391        }
392    }
393
394    fn error_context(&self, row: &impl Row) -> String {
395        let pk = row.project(&self.pk_indices);
396        let jk = row.project(&self.join_key_indices);
397        format!(
398            "join key: {}, pk: {}, row: {}, state_table_id: {}",
399            jk.display(),
400            pk.display(),
401            row.display(),
402            self.table.table_id()
403        )
404    }
405}
406
407impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
408    /// Create a [`JoinHashMap`] with the given LRU capacity.
409    #[allow(clippy::too_many_arguments)]
410    pub fn new(
411        watermark_sequence: AtomicU64Ref,
412        join_key_data_types: Vec<DataType>,
413        state_join_key_indices: Vec<usize>,
414        state_all_data_types: Vec<DataType>,
415        state_table: StateTable<S>,
416        state_pk_indices: Vec<usize>,
417        degree_state: Option<TableInner<S>>,
418        null_matched: K::Bitmap,
419        pk_contained_in_jk: bool,
420        inequality_key_idx: Option<usize>,
421        metrics: Arc<StreamingMetrics>,
422        actor_id: ActorId,
423        fragment_id: FragmentId,
424        side: &'static str,
425    ) -> Self {
426        // TODO: unify pk encoding with state table.
427        let pk_data_types = state_pk_indices
428            .iter()
429            .map(|i| state_all_data_types[*i].clone())
430            .collect();
431        let pk_serializer = OrderedRowSerde::new(
432            pk_data_types,
433            vec![OrderType::ascending(); state_pk_indices.len()],
434        );
435
436        let inequality_key_desc = inequality_key_idx.map(|idx| {
437            let serializer = OrderedRowSerde::new(
438                vec![state_all_data_types[idx].clone()],
439                vec![OrderType::ascending()],
440            );
441            InequalityKeyDesc { idx, serializer }
442        });
443
444        let join_table_id = state_table.table_id();
445        let state = TableInner {
446            pk_indices: state_pk_indices,
447            join_key_indices: state_join_key_indices,
448            order_key_indices: state_table.pk_indices().to_vec(),
449            table: state_table,
450        };
451
452        let need_degree_table = degree_state.is_some();
453
454        let metrics_info = MetricsInfo::new(
455            metrics.clone(),
456            join_table_id,
457            actor_id,
458            format!("hash join {}", side),
459        );
460
461        let cache = ManagedLruCache::unbounded_with_hasher(
462            watermark_sequence,
463            metrics_info,
464            PrecomputedBuildHasher,
465        );
466
467        Self {
468            inner: cache,
469            join_key_data_types,
470            null_matched,
471            pk_serializer,
472            state,
473            degree_state,
474            need_degree_table,
475            pk_contained_in_jk,
476            inequality_key_desc,
477            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
478            _marker: std::marker::PhantomData,
479        }
480    }
481
482    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
483        self.state.table.init_epoch(epoch).await?;
484        if let Some(degree_state) = &mut self.degree_state {
485            degree_state.table.init_epoch(epoch).await?;
486        }
487        Ok(())
488    }
489}
490
491impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
492    pub async fn post_yield_barrier(
493        self,
494        vnode_bitmap: Option<Arc<Bitmap>>,
495    ) -> StreamExecutorResult<Option<bool>> {
496        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
497        if let Some(degree_state) = self.degree_state {
498            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
499        }
500        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
501        if cache_may_stale.unwrap_or(false) {
502            self.inner.clear();
503        }
504        Ok(cache_may_stale)
505    }
506}
507impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
508    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
509        // TODO: remove data in cache.
510        self.state.table.update_watermark(watermark.clone());
511        if let Some(degree_state) = &mut self.degree_state {
512            degree_state.table.update_watermark(watermark);
513        }
514    }
515
516    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
517    /// `update_state` after some operations to put the state back.
518    ///
519    /// If the state does not exist in the cache, fetch the remote storage and return. If it still
520    /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be
521    /// returned.
522    ///
523    /// Note: This will NOT remove anything from remote storage.
524    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
525        self.metrics.total_lookup_count += 1;
526        if self.inner.contains(key) {
527            tracing::trace!("hit cache for join key: {:?}", key);
528            // Do not update the LRU statistics here with `peek_mut` since we will put the state
529            // back.
530            let mut state = self.inner.peek_mut(key).expect("checked contains");
531            CacheResult::Hit(state.take())
532        } else {
533            self.metrics.lookup_miss_count += 1;
534            tracing::trace!("miss cache for join key: {:?}", key);
535            CacheResult::Miss
536        }
537    }
538
539    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
540    /// `update_state` after some operations to put the state back.
541    ///
542    /// If the state does not exist in the cache, fetch the remote storage and return. If it still
543    /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be
544    /// returned.
545    ///
546    /// Note: This will NOT remove anything from remote storage.
547    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
548        self.metrics.total_lookup_count += 1;
549        let state = if self.inner.contains(key) {
550            // Do not update the LRU statistics here with `peek_mut` since we will put the state
551            // back.
552            let mut state = self.inner.peek_mut(key).unwrap();
553            state.take()
554        } else {
555            self.metrics.lookup_miss_count += 1;
556            self.fetch_cached_state(key).await?.into()
557        };
558        Ok(state)
559    }
560
561    /// Fetch cache from the state store. Should only be called if the key does not exist in memory.
562    /// Will return a empty `JoinEntryState` even when state does not exist in remote.
563    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
564        let key = key.deserialize(&self.join_key_data_types)?;
565
566        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();
567
568        if self.need_degree_table {
569            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
570                &(Bound::Unbounded, Bound::Unbounded);
571            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
572                &key,
573                sub_range,
574                PrefetchOptions::default(),
575            );
576            let degree_state = self.degree_state.as_ref().unwrap();
577            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
578                &key,
579                sub_range,
580                PrefetchOptions::default(),
581            );
582
583            let (table_iter, degree_table_iter) =
584                try_join(table_iter_fut, degree_table_iter_fut).await?;
585
586            let mut pinned_table_iter = std::pin::pin!(table_iter);
587            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);
588
589            // For better tolerating inconsistent stream, we have to first buffer all rows and
590            // degree rows, and check the number of them, then iterate on them.
591            let mut rows = vec![];
592            let mut degree_rows = vec![];
593            let mut inconsistency_happened = false;
594            loop {
595                let (row, degree_row) =
596                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
597                let (row, degree_row) = match (row, degree_row) {
598                    (None, None) => break,
599                    (None, Some(_)) => {
600                        inconsistency_happened = true;
601                        consistency_panic!(
602                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
603                            &key
604                        );
605                        break;
606                    }
607                    (Some(_), None) => {
608                        inconsistency_happened = true;
609                        consistency_panic!(
610                            "mismatched row and degree table of join key: {:?}, input table has more rows",
611                            &key
612                        );
613                        break;
614                    }
615                    (Some(r), Some(d)) => (r, d),
616                };
617
618                let row = row?;
619                let degree_row = degree_row?;
620                rows.push(row);
621                degree_rows.push(degree_row);
622            }
623
624            if inconsistency_happened {
625                // Pk-based row-degree pairing.
626                assert_ne!(rows.len(), degree_rows.len());
627
628                let row_iter = stream::iter(rows.into_iter()).peekable();
629                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
630                pin_mut!(row_iter);
631                pin_mut!(degree_row_iter);
632
633                loop {
634                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
635                        (None, _) | (_, None) => break,
636                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
637                            Ordering::Greater => {
638                                degree_row_iter.next().await;
639                            }
640                            Ordering::Less => {
641                                row_iter.next().await;
642                            }
643                            Ordering::Equal => {
644                                let row =
645                                    row_iter.next().await.expect("we matched some(row) above");
646                                let degree_row = degree_row_iter
647                                    .next()
648                                    .await
649                                    .expect("we matched some(degree_row) above");
650                                let pk = row
651                                    .as_ref()
652                                    .project(&self.state.pk_indices)
653                                    .memcmp_serialize(&self.pk_serializer);
654                                let degree_i64 = degree_row
655                                    .datum_at(degree_row.len() - 1)
656                                    .expect("degree should not be NULL");
657                                let inequality_key = self
658                                    .inequality_key_desc
659                                    .as_ref()
660                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
661                                entry_state
662                                    .insert(
663                                        pk,
664                                        E::encode(&JoinRow::new(
665                                            row.row(),
666                                            degree_i64.into_int64() as u64,
667                                        )),
668                                        inequality_key,
669                                    )
670                                    .with_context(|| self.state.error_context(row.row()))?;
671                            }
672                        },
673                    }
674                }
675            } else {
676                // 1 to 1 row-degree pairing.
677                // Actually it's possible that both the input data table and the degree table missed
678                // some equal number of rows, but let's ignore this case because it should be rare.
679
680                assert_eq!(rows.len(), degree_rows.len());
681
682                #[for_await]
683                for (row, degree_row) in
684                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
685                {
686                    let row: KeyedRow<_> = row;
687                    let degree_row: KeyedRow<_> = degree_row;
688
689                    let pk1 = row.key();
690                    let pk2 = degree_row.key();
691                    debug_assert_eq!(
692                        pk1, pk2,
693                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
694                    );
695                    let pk = row
696                        .as_ref()
697                        .project(&self.state.pk_indices)
698                        .memcmp_serialize(&self.pk_serializer);
699                    let inequality_key = self
700                        .inequality_key_desc
701                        .as_ref()
702                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
703                    let degree_i64 = degree_row
704                        .datum_at(degree_row.len() - 1)
705                        .expect("degree should not be NULL");
706                    entry_state
707                        .insert(
708                            pk,
709                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
710                            inequality_key,
711                        )
712                        .with_context(|| self.state.error_context(row.row()))?;
713                }
714            }
715        } else {
716            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
717                &(Bound::Unbounded, Bound::Unbounded);
718            let table_iter = self
719                .state
720                .table
721                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
722                .await?;
723
724            #[for_await]
725            for entry in table_iter {
726                let row: KeyedRow<_> = entry?;
727                let pk = row
728                    .as_ref()
729                    .project(&self.state.pk_indices)
730                    .memcmp_serialize(&self.pk_serializer);
731                let inequality_key = self
732                    .inequality_key_desc
733                    .as_ref()
734                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
735                entry_state
736                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
737                    .with_context(|| self.state.error_context(row.row()))?;
738            }
739        };
740
741        Ok(entry_state)
742    }
743
744    pub async fn flush(
745        &mut self,
746        epoch: EpochPair,
747    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
748        self.metrics.flush();
749        let state_post_commit = self.state.table.commit(epoch).await?;
750        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
751            Some(degree_state.table.commit(epoch).await?)
752        } else {
753            None
754        };
755        Ok(JoinHashMapPostCommit {
756            state: state_post_commit,
757            degree_state: degree_state_post_commit,
758            inner: &mut self.inner,
759        })
760    }
761
762    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
763        self.state.table.try_flush().await?;
764        if let Some(degree_state) = &mut self.degree_state {
765            degree_state.table.try_flush().await?;
766        }
767        Ok(())
768    }
769
770    pub fn insert_handle_degree(
771        &mut self,
772        key: &K,
773        value: JoinRow<impl Row>,
774    ) -> StreamExecutorResult<()> {
775        if self.need_degree_table {
776            self.insert(key, value)
777        } else {
778            self.insert_row(key, value.row)
779        }
780    }
781
782    /// Insert a join row
783    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
784        let pk = self.serialize_pk_from_row(&value.row);
785
786        let inequality_key = self
787            .inequality_key_desc
788            .as_ref()
789            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
790
791        // TODO(yuhao): avoid this `contains`.
792        // https://github.com/risingwavelabs/risingwave/issues/9233
793        if self.inner.contains(key) {
794            // Update cache
795            let mut entry = self.inner.get_mut(key).expect("checked contains");
796            entry
797                .insert(pk, E::encode(&value), inequality_key)
798                .with_context(|| self.state.error_context(&value.row))?;
799        } else if self.pk_contained_in_jk {
800            // Refill cache when the join key exist in neither cache or storage.
801            self.metrics.insert_cache_miss_count += 1;
802            let mut entry: JoinEntryState<E> = JoinEntryState::default();
803            entry
804                .insert(pk, E::encode(&value), inequality_key)
805                .with_context(|| self.state.error_context(&value.row))?;
806            self.update_state(key, entry.into());
807        }
808
809        // Update the flush buffer.
810        if let Some(degree_state) = self.degree_state.as_mut() {
811            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
812            self.state.table.insert(row);
813            degree_state.table.insert(degree);
814        } else {
815            self.state.table.insert(value.row);
816        }
817        Ok(())
818    }
819
820    /// Insert a row.
821    /// Used when the side does not need to update degree.
822    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
823        let join_row = JoinRow::new(&value, 0);
824        self.insert(key, join_row)?;
825        Ok(())
826    }
827
828    pub fn delete_handle_degree(
829        &mut self,
830        key: &K,
831        value: JoinRow<impl Row>,
832    ) -> StreamExecutorResult<()> {
833        if self.need_degree_table {
834            self.delete(key, value)
835        } else {
836            self.delete_row(key, value.row)
837        }
838    }
839
840    /// Delete a join row
841    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
842        if let Some(mut entry) = self.inner.get_mut(key) {
843            let pk = (&value.row)
844                .project(&self.state.pk_indices)
845                .memcmp_serialize(&self.pk_serializer);
846            let inequality_key = self
847                .inequality_key_desc
848                .as_ref()
849                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
850            entry
851                .remove(pk, inequality_key.as_ref())
852                .with_context(|| self.state.error_context(&value.row))?;
853        }
854
855        // If no cache maintained, only update the state table.
856        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
857        self.state.table.delete(row);
858        let degree_state = self.degree_state.as_mut().expect("degree table missing");
859        degree_state.table.delete(degree);
860        Ok(())
861    }
862
863    /// Delete a row
864    /// Used when the side does not need to update degree.
865    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
866        if let Some(mut entry) = self.inner.get_mut(key) {
867            let pk = (&value)
868                .project(&self.state.pk_indices)
869                .memcmp_serialize(&self.pk_serializer);
870
871            let inequality_key = self
872                .inequality_key_desc
873                .as_ref()
874                .map(|desc| desc.serialize_inequal_key_from_row(&value));
875            entry
876                .remove(pk, inequality_key.as_ref())
877                .with_context(|| self.state.error_context(&value))?;
878        }
879
880        // If no cache maintained, only update the state table.
881        self.state.table.delete(value);
882        Ok(())
883    }
884
885    /// Update a [`JoinEntryState`] into the hash table.
886    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
887        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
888    }
889
890    /// Evict the cache.
891    pub fn evict(&mut self) {
892        self.inner.evict();
893    }
894
895    /// Cached entry count for this hash table.
896    pub fn entry_count(&self) -> usize {
897        self.inner.len()
898    }
899
900    pub fn null_matched(&self) -> &K::Bitmap {
901        &self.null_matched
902    }
903
904    pub fn table_id(&self) -> TableId {
905        self.state.table.table_id()
906    }
907
908    pub fn join_key_data_types(&self) -> &[DataType] {
909        &self.join_key_data_types
910    }
911
912    /// Return true if the inequality key is null.
913    /// # Panics
914    /// Panics if the inequality key is not set.
915    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
916        let desc = self
917            .inequality_key_desc
918            .as_ref()
919            .expect("inequality key desc missing");
920        row.datum_at(desc.idx).is_none()
921    }
922
923    /// Serialize the inequality key from a row.
924    /// # Panics
925    /// Panics if the inequality key is not set.
926    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
927        self.inequality_key_desc
928            .as_ref()
929            .expect("inequality key desc missing")
930            .serialize_inequal_key_from_row(&row)
931    }
932
933    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
934        row.project(&self.state.pk_indices)
935            .memcmp_serialize(&self.pk_serializer)
936    }
937}
938
/// Handle returned by `flush` once the state table(s) have been committed.
///
/// It is marked `#[must_use]`, so discarding it triggers a compiler warning.
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Post-commit handle of the row state table.
    state: StateTablePostCommit<'a, S>,
    /// Post-commit handle of the degree state table; `None` when there is no degree table.
    degree_state: Option<StateTablePostCommit<'a, S>>,
    /// Mutable borrow of the in-memory cache of the hash map that produced this handle.
    inner: &'a mut JoinHashMapInner<K, E>,
}
945
946use risingwave_common::catalog::TableId;
947use risingwave_common_estimate_size::KvSize;
948use thiserror::Error;
949
950use super::*;
951use crate::executor::prelude::{Stream, try_stream};
952
/// In-memory state for all entries belonging to one join key.
/// When evicted, `cached` does not hold any entries.
///
/// If a `JoinEntryState` exists for a join key, all records under this
/// join key are present in the cache.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The full copy of the state, keyed by the serialized primary key.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Index used for AS OF join. The key is the inequality column value. The value is the set
    /// of primary keys in `cached` that share this inequality key.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    /// Running size estimate, updated as entries are inserted into and removed from the sets
    /// above.
    kv_heap_size: KvSize,
}
966
967impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
968    fn estimated_heap_size(&self) -> usize {
969        // TODO: Add btreemap internal size.
970        // https://github.com/risingwavelabs/risingwave/issues/9713
971        self.kv_heap_size.size()
972    }
973}
974
/// Inconsistencies detected while maintaining a [`JoinEntryState`].
#[derive(Error, Debug)]
pub enum JoinEntryError {
    /// Returned by `insert` when the primary key is already present in the cache.
    #[error("double inserting a join state entry")]
    Occupied,
    /// Returned by `remove` under strict consistency when the primary key is not in the cache.
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    /// Reported when a primary key found in the inequality index has no row in the cache.
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}
984
985impl<E: JoinEncoding> JoinEntryState<E> {
986    /// Insert into the cache.
987    pub fn insert(
988        &mut self,
989        key: PkType,
990        value: E::EncodedRow,
991        inequality_key: Option<InequalKeyType>,
992    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
993        let mut removed = false;
994        if !enable_strict_consistency() {
995            // strict consistency is off, let's remove existing (if any) first
996            if let Some(old_value) = self.cached.remove(&key) {
997                if let Some(inequality_key) = inequality_key.as_ref() {
998                    self.remove_pk_from_inequality_index(&key, inequality_key);
999                }
1000                self.kv_heap_size.sub(&key, &old_value);
1001                removed = true;
1002            }
1003        }
1004
1005        self.kv_heap_size.add(&key, &value);
1006
1007        if let Some(inequality_key) = inequality_key {
1008            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
1009        }
1010        let ret = self.cached.try_insert(key.clone(), value);
1011
1012        if !enable_strict_consistency() {
1013            assert!(ret.is_ok(), "we have removed existing entry, if any");
1014            if removed {
1015                // if not silent, we should log the error
1016                consistency_error!(?key, "double inserting a join state entry");
1017            }
1018        }
1019
1020        ret.map_err(|_| JoinEntryError::Occupied)
1021    }
1022
1023    /// Delete from the cache.
1024    pub fn remove(
1025        &mut self,
1026        pk: PkType,
1027        inequality_key: Option<&InequalKeyType>,
1028    ) -> Result<(), JoinEntryError> {
1029        if let Some(value) = self.cached.remove(&pk) {
1030            self.kv_heap_size.sub(&pk, &value);
1031            if let Some(inequality_key) = inequality_key {
1032                self.remove_pk_from_inequality_index(&pk, inequality_key);
1033            }
1034            Ok(())
1035        } else if enable_strict_consistency() {
1036            Err(JoinEntryError::Remove)
1037        } else {
1038            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
1039            Ok(())
1040        }
1041    }
1042
1043    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
1044        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
1045            if pk_set.remove(pk).is_none() {
1046                if enable_strict_consistency() {
1047                    panic!("removing a pk that it not in the inequality index");
1048                } else {
1049                    consistency_error!(?pk, "removing a pk that it not in the inequality index");
1050                };
1051            } else {
1052                self.kv_heap_size.sub(pk, &());
1053            }
1054            if pk_set.is_empty() {
1055                self.inequality_index.remove(inequality_key);
1056            }
1057        }
1058    }
1059
1060    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
1061        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
1062            let pk_size = pk.estimated_size();
1063            if pk_set.try_insert(pk, ()).is_err() {
1064                if enable_strict_consistency() {
1065                    panic!("inserting a pk that it already in the inequality index");
1066                } else {
1067                    consistency_error!("inserting a pk that it already in the inequality index");
1068                };
1069            } else {
1070                self.kv_heap_size.add_size(pk_size);
1071            }
1072        } else {
1073            let mut pk_set = JoinRowSet::default();
1074            pk_set.try_insert(pk, ()).expect("pk set should be empty");
1075            self.inequality_index
1076                .try_insert(inequality_key, pk_set)
1077                .expect("pk set should be empty");
1078        }
1079    }
1080
1081    pub fn get(
1082        &self,
1083        pk: &PkType,
1084        data_types: &[DataType],
1085    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1086        self.cached
1087            .get(pk)
1088            .map(|encoded| encoded.decode(data_types))
1089    }
1090
1091    /// Note: the first item in the tuple is the mutable reference to the value in this entry, while
1092    /// the second item is the decoded value. To mutate the degree, one **must not** forget to apply
1093    /// the changes to the first item.
1094    ///
1095    /// WARNING: Should not change the heap size of `StateValueType` with the mutable reference.
1096    pub fn values_mut<'a>(
1097        &'a mut self,
1098        data_types: &'a [DataType],
1099    ) -> impl Iterator<
1100        Item = (
1101            &'a mut E::EncodedRow,
1102            StreamExecutorResult<JoinRow<E::DecodedRow>>,
1103        ),
1104    > + 'a {
1105        self.cached.values_mut().map(|encoded| {
1106            let decoded = encoded.decode(data_types);
1107            (encoded, decoded)
1108        })
1109    }
1110
1111    pub fn len(&self) -> usize {
1112        self.cached.len()
1113    }
1114
1115    /// Range scan the cache using the inequality index.
1116    pub fn range_by_inequality<'a, R>(
1117        &'a self,
1118        range: R,
1119        data_types: &'a [DataType],
1120    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<E::DecodedRow>>> + 'a
1121    where
1122        R: RangeBounds<InequalKeyType> + 'a,
1123    {
1124        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
1125            pk_set
1126                .keys()
1127                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
1128        })
1129    }
1130
1131    /// Get the records whose inequality key upper bound satisfy the given bound.
1132    pub fn upper_bound_by_inequality<'a>(
1133        &'a self,
1134        bound: Bound<&InequalKeyType>,
1135        data_types: &'a [DataType],
1136    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1137        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
1138            if let Some(pk) = pk_set.first_key_sorted() {
1139                self.get_by_indexed_pk(pk, data_types)
1140            } else {
1141                panic!("pk set for a index record must has at least one element");
1142            }
1143        } else {
1144            None
1145        }
1146    }
1147
1148    pub fn get_by_indexed_pk(
1149        &self,
1150        pk: &PkType,
1151        data_types: &[DataType],
1152    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>>
1153where {
1154        if let Some(value) = self.cached.get(pk) {
1155            Some(value.decode(data_types))
1156        } else if enable_strict_consistency() {
1157            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
1158        } else {
1159            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
1160            None
1161        }
1162    }
1163
1164    /// Get the records whose inequality key lower bound satisfy the given bound.
1165    pub fn lower_bound_by_inequality<'a>(
1166        &'a self,
1167        bound: Bound<&InequalKeyType>,
1168        data_types: &'a [DataType],
1169    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1170        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
1171            if let Some(pk) = pk_set.first_key_sorted() {
1172                self.get_by_indexed_pk(pk, data_types)
1173            } else {
1174                panic!("pk set for a index record must has at least one element");
1175            }
1176        } else {
1177            None
1178        }
1179    }
1180
1181    pub fn get_first_by_inequality<'a>(
1182        &'a self,
1183        inequality_key: &InequalKeyType,
1184        data_types: &'a [DataType],
1185    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1186        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
1187            if let Some(pk) = pk_set.first_key_sorted() {
1188                self.get_by_indexed_pk(pk, data_types)
1189            } else {
1190                panic!("pk set for a index record must has at least one element");
1191            }
1192        } else {
1193            None
1194        }
1195    }
1196
1197    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
1198        &self.inequality_index
1199    }
1200}
1201
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    /// Insert every row of `data_chunk` into `managed_state` with degree 0.
    ///
    /// The pk is the projection of the row onto `pk_indices`; when `inequality_key_idx` is given,
    /// that single column is serialized as the inequality key.
    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_serializer = inequality_key_idx.map(|idx| {
            OrderedRowSerde::new(vec![col_types[idx].clone()], vec![OrderType::ascending()])
        });

        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            // Project by the pk columns themselves, so the helper does not depend on how many
            // non-pk columns the chunk happens to have.
            let pk = (&row).project(pk_indices).memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    /// Assert that iterating `managed_state` yields exactly the rows `(col1[i], col2[i])`, in
    /// order, each with degree 0.
    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        let expected = col1.iter().zip_eq_debug(col2.iter());
        for ((_, matched_row), (d1, d2)) in
            managed_state.values_mut(col_types).zip_eq_debug(expected)
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        // `Vec` in state
        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        // `Vec` in state
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // Two rows share inequality key 5; the one with the smallest pk is expected.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        // Excluding 5 leaves inequality key 4, which belongs to pk 3.
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // 8 is the largest inequality key, so nothing lies strictly above it.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}