risingwave_stream/executor/join/
hash_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use std::alloc::Global;
15use std::cmp::Ordering;
16use std::ops::{Bound, Deref, DerefMut, RangeBounds};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use futures::future::{join, try_join};
21use futures::{StreamExt, pin_mut, stream};
22use futures_async_stream::for_await;
23use join_row_set::JoinRowSet;
24use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
25use risingwave_common::bitmap::Bitmap;
26use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
27use risingwave_common::metrics::LabelGuardedIntCounter;
28use risingwave_common::row::{OwnedRow, Row, RowExt, once};
29use risingwave_common::types::{DataType, ScalarImpl};
30use risingwave_common::util::epoch::EpochPair;
31use risingwave_common::util::iter_util::ZipEqFast;
32use risingwave_common::util::row_serde::OrderedRowSerde;
33use risingwave_common::util::sort_util::OrderType;
34use risingwave_common_estimate_size::EstimateSize;
35use risingwave_storage::StateStore;
36use risingwave_storage::store::PrefetchOptions;
37use risingwave_storage::table::KeyedRow;
38use thiserror_ext::AsReport;
39
40use super::row::{CachedJoinRow, DegreeType};
41use crate::cache::ManagedLruCache;
42use crate::common::metrics::MetricsInfo;
43use crate::common::table::state_table::{StateTable, StateTablePostCommit};
44use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
45use crate::executor::error::StreamExecutorResult;
46use crate::executor::join::row::JoinRow;
47use crate::executor::monitor::StreamingMetrics;
48use crate::executor::{JoinEncoding, StreamExecutorError};
49use crate::task::{ActorId, AtomicU64Ref, FragmentId};
50
51/// Memcomparable encoding.
52type PkType = Vec<u8>;
53type InequalKeyType = Vec<u8>;
54
55pub type HashValueType<E> = Box<JoinEntryState<E>>;
56
57impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
58    fn estimated_heap_size(&self) -> usize {
59        self.as_ref().estimated_heap_size()
60    }
61}
62
63/// The wrapper for [`JoinEntryState`] which should be `Some` most of the time in the hash table.
64///
65/// When the executor is operating on the specific entry of the map, it can hold the ownership of
66/// the entry by taking the value out of the `Option`, instead of holding a mutable reference to the
67/// map, which can make the compiler happy.
68struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
69
70pub(crate) enum CacheResult<E: JoinEncoding> {
71    NeverMatch,            // Will never match, will not be in cache at all.
72    Miss,                  // Cache-miss
73    Hit(HashValueType<E>), // Cache-hit
74}
75
76impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
77    fn estimated_heap_size(&self) -> usize {
78        self.0.estimated_heap_size()
79    }
80}
81
82impl<E: JoinEncoding> HashValueWrapper<E> {
83    const MESSAGE: &'static str = "the state should always be `Some`";
84
85    /// Take the value out of the wrapper. Panic if the value is `None`.
86    pub fn take(&mut self) -> HashValueType<E> {
87        self.0.take().expect(Self::MESSAGE)
88    }
89}
90
91impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
92    type Target = HashValueType<E>;
93
94    fn deref(&self) -> &Self::Target {
95        self.0.as_ref().expect(Self::MESSAGE)
96    }
97}
98
99impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
100    fn deref_mut(&mut self) -> &mut Self::Target {
101        self.0.as_mut().expect(Self::MESSAGE)
102    }
103}
104
105type JoinHashMapInner<K, E> =
106    ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;
107
108pub struct JoinHashMapMetrics {
109    /// Basic information
110    /// How many times have we hit the cache of join executor
111    lookup_miss_count: usize,
112    total_lookup_count: usize,
113    /// How many times have we miss the cache when insert row
114    insert_cache_miss_count: usize,
115
116    // Metrics
117    join_lookup_total_count_metric: LabelGuardedIntCounter,
118    join_lookup_miss_count_metric: LabelGuardedIntCounter,
119    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
120}
121
122impl JoinHashMapMetrics {
123    pub fn new(
124        metrics: &StreamingMetrics,
125        actor_id: ActorId,
126        fragment_id: FragmentId,
127        side: &'static str,
128        join_table_id: u32,
129    ) -> Self {
130        let actor_id = actor_id.to_string();
131        let fragment_id = fragment_id.to_string();
132        let join_table_id = join_table_id.to_string();
133        let join_lookup_total_count_metric = metrics
134            .join_lookup_total_count
135            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
136        let join_lookup_miss_count_metric = metrics
137            .join_lookup_miss_count
138            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
139        let join_insert_cache_miss_count_metrics = metrics
140            .join_insert_cache_miss_count
141            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
142
143        Self {
144            lookup_miss_count: 0,
145            total_lookup_count: 0,
146            insert_cache_miss_count: 0,
147            join_lookup_total_count_metric,
148            join_lookup_miss_count_metric,
149            join_insert_cache_miss_count_metrics,
150        }
151    }
152
153    pub fn flush(&mut self) {
154        self.join_lookup_total_count_metric
155            .inc_by(self.total_lookup_count as u64);
156        self.join_lookup_miss_count_metric
157            .inc_by(self.lookup_miss_count as u64);
158        self.join_insert_cache_miss_count_metrics
159            .inc_by(self.insert_cache_miss_count as u64);
160        self.total_lookup_count = 0;
161        self.lookup_miss_count = 0;
162        self.insert_cache_miss_count = 0;
163    }
164}
165
166/// Inequality key description for `AsOf` join.
167struct InequalityKeyDesc {
168    idx: usize,
169    serializer: OrderedRowSerde,
170}
171
172impl InequalityKeyDesc {
173    /// Serialize the inequality key from a row.
174    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
175        let indices = vec![self.idx];
176        let inequality_key = row.project(&indices);
177        inequality_key.memcmp_serialize(&self.serializer)
178    }
179}
180
181pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
182    /// Store the join states.
183    inner: JoinHashMapInner<K, E>,
184    /// Data types of the join key columns
185    join_key_data_types: Vec<DataType>,
186    /// Null safe bitmap for each join pair
187    null_matched: K::Bitmap,
188    /// The memcomparable serializer of primary key.
189    pk_serializer: OrderedRowSerde,
190    /// State table. Contains the data from upstream.
191    state: TableInner<S>,
192    /// Degree table.
193    ///
194    /// The degree is generated from the hash join executor.
195    /// Each row in `state` has a corresponding degree in `degree state`.
196    /// A degree value `d` in for a row means the row has `d` matched row in the other join side.
197    ///
198    /// It will only be used when needed in a side.
199    ///
200    /// - Full Outer: both side
201    /// - Left Outer/Semi/Anti: left side
202    /// - Right Outer/Semi/Anti: right side
203    /// - Inner: neither side.
204    ///
205    /// Should be set to `None` if `need_degree_table` was set to `false`.
206    ///
207    /// The degree of each row will tell us if we need to emit `NULL` for the row.
208    /// For instance, given `lhs LEFT JOIN rhs`,
209    /// If the degree of a row in `lhs` is 0, it means the row does not have a match in `rhs`.
210    /// If the degree of a row in `lhs` is 2, it means the row has two matches in `rhs`.
211    /// Now, when emitting the result of the join, we need to emit `NULL` for the row in `lhs` if
212    /// the degree is 0.
213    ///
214    /// Why don't just use a boolean value instead of a degree count?
215    /// Consider the case where we delete a matched record from `rhs`.
216    /// Since we can delete a record,
217    /// there must have been a record in `rhs` that matched the record in `lhs`.
218    /// So this value is `true`.
219    /// But we don't know how many records are matched after removing this record,
220    /// since we only stored a boolean value rather than the count.
221    /// Hence we need to store the count of matched records.
222    degree_state: Option<TableInner<S>>,
223    // TODO(kwannoel): Make this `const` instead.
224    /// If degree table is need
225    need_degree_table: bool,
226    /// Pk is part of the join key.
227    pk_contained_in_jk: bool,
228    /// Inequality key description for `AsOf` join.
229    inequality_key_desc: Option<InequalityKeyDesc>,
230    /// Metrics of the hash map
231    metrics: JoinHashMapMetrics,
232    _marker: std::marker::PhantomData<E>,
233}
234
235impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
236    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
237        (&self.state.order_key_indices, &mut self.degree_state)
238    }
239
240    /// NOTE(kwannoel): This allows us to concurrently stream records from the `state_table`,
241    /// and update the degree table, without using `unsafe` code.
242    ///
243    /// This is because we obtain separate references to separate parts of the `JoinHashMap`,
244    /// instead of reusing the same reference to `JoinHashMap` for concurrent read access to `state_table`,
245    /// and write access to the degree table.
246    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
247        &'a mut self,
248        key: &'a K,
249    ) -> StreamExecutorResult<(
250        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
251        &'a [usize],
252        &'a mut Option<TableInner<S>>,
253    )> {
254        let degree_state = &mut self.degree_state;
255        let (order_key_indices, pk_indices, state_table) = (
256            &self.state.order_key_indices,
257            &self.state.pk_indices,
258            &mut self.state.table,
259        );
260        let degrees = if let Some(degree_state) = degree_state {
261            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
262        } else {
263            None
264        };
265        let stream = into_stream(
266            &self.join_key_data_types,
267            pk_indices,
268            &self.pk_serializer,
269            state_table,
270            key,
271            degrees,
272        );
273        Ok((stream, order_key_indices, &mut self.degree_state))
274    }
275}
276
277#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
278pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
279    join_key_data_types: &'a [DataType],
280    pk_indices: &'a [usize],
281    pk_serializer: &'a OrderedRowSerde,
282    state_table: &'a StateTable<S>,
283    key: &'a K,
284    degrees: Option<Vec<DegreeType>>,
285) {
286    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
287    let decoded_key = key.deserialize(join_key_data_types)?;
288    let table_iter = state_table
289        .iter_with_prefix(&decoded_key, sub_range, PrefetchOptions::default())
290        .await?;
291
292    #[for_await]
293    for (i, entry) in table_iter.enumerate() {
294        let encoded_row = entry?;
295        let encoded_pk = encoded_row
296            .as_ref()
297            .project(pk_indices)
298            .memcmp_serialize(pk_serializer);
299        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
300        yield (encoded_pk, join_row);
301    }
302}
303
304/// We use this to fetch ALL degrees into memory.
305/// We use this instead of a streaming interface.
306/// It is necessary because we must update the `degree_state_table` concurrently.
307/// If we obtain the degrees in a stream,
308/// we will need to hold an immutable reference to the state table for the entire lifetime,
309/// preventing us from concurrently updating the state table.
310///
311/// The cost of fetching all degrees upfront is acceptable. We currently already do so
312/// in `fetch_cached_state`.
313/// The memory use should be limited since we only store a u64.
314///
315/// Let's say we have amplification of 1B, we will have 1B * 8 bytes ~= 8GB
316///
317/// We can also have further optimization, to permit breaking the streaming update,
318/// to flush the in-memory degrees, if this is proven to have high memory consumption.
319///
320/// TODO(kwannoel): Perhaps we can cache these separately from matched rows too.
321/// Because matched rows may occupy a larger capacity.
322///
323/// Argument for this:
324/// We only hit this when cache miss. When cache miss, we will have this as one off cost.
325/// Keeping this cached separately from matched rows is beneficial.
326/// Then we can evict matched rows, without touching the degrees.
327async fn fetch_degrees<K: HashKey, S: StateStore>(
328    key: &K,
329    join_key_data_types: &[DataType],
330    degree_state_table: &StateTable<S>,
331) -> StreamExecutorResult<Vec<DegreeType>> {
332    let key = key.deserialize(join_key_data_types)?;
333    let mut degrees = vec![];
334    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
335    let table_iter = degree_state_table
336        .iter_with_prefix(key, sub_range, PrefetchOptions::default())
337        .await?;
338    #[for_await]
339    for entry in table_iter {
340        let degree_row = entry?;
341        let degree_i64 = degree_row
342            .datum_at(degree_row.len() - 1)
343            .expect("degree should not be NULL");
344        degrees.push(degree_i64.into_int64() as u64);
345    }
346    Ok(degrees)
347}
348
349// NOTE(kwannoel): This is not really specific to `TableInner`.
350// A degree table is `TableInner`, a `TableInner` might not be a degree table.
351// Hence we don't specify it in its impl block.
352pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
353    order_key_indices: &[usize],
354    degree_state: &mut TableInner<S>,
355    matched_row: &mut JoinRow<impl Row>,
356) {
357    let old_degree_row = (&matched_row.row)
358        .project(order_key_indices)
359        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
360    if INCREMENT {
361        matched_row.degree += 1;
362    } else {
363        // DECREMENT
364        matched_row.degree -= 1;
365    }
366    let new_degree_row = (&matched_row.row)
367        .project(order_key_indices)
368        .chain(once(Some(ScalarImpl::Int64(matched_row.degree as i64))));
369    degree_state.table.update(old_degree_row, new_degree_row);
370}
371
372pub struct TableInner<S: StateStore> {
373    /// Indices of the (cache) pk in a state row
374    pk_indices: Vec<usize>,
375    /// Indices of the join key in a state row
376    join_key_indices: Vec<usize>,
377    /// The order key of the join side has the following format:
378    /// | `join_key` ... | pk ... |
379    /// Where `join_key` contains all the columns not in the pk.
380    /// It should be a superset of the pk.
381    order_key_indices: Vec<usize>,
382    pub(crate) table: StateTable<S>,
383}
384
385impl<S: StateStore> TableInner<S> {
386    pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
387        let order_key_indices = table.pk_indices().to_vec();
388        Self {
389            pk_indices,
390            join_key_indices,
391            order_key_indices,
392            table,
393        }
394    }
395
396    fn error_context(&self, row: &impl Row) -> String {
397        let pk = row.project(&self.pk_indices);
398        let jk = row.project(&self.join_key_indices);
399        format!(
400            "join key: {}, pk: {}, row: {}, state_table_id: {}",
401            jk.display(),
402            pk.display(),
403            row.display(),
404            self.table.table_id()
405        )
406    }
407}
408
409impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
410    /// Create a [`JoinHashMap`] with the given LRU capacity.
411    #[allow(clippy::too_many_arguments)]
412    pub fn new(
413        watermark_sequence: AtomicU64Ref,
414        join_key_data_types: Vec<DataType>,
415        state_join_key_indices: Vec<usize>,
416        state_all_data_types: Vec<DataType>,
417        state_table: StateTable<S>,
418        state_pk_indices: Vec<usize>,
419        degree_state: Option<TableInner<S>>,
420        null_matched: K::Bitmap,
421        pk_contained_in_jk: bool,
422        inequality_key_idx: Option<usize>,
423        metrics: Arc<StreamingMetrics>,
424        actor_id: ActorId,
425        fragment_id: FragmentId,
426        side: &'static str,
427    ) -> Self {
428        let alloc = StatsAlloc::new(Global).shared();
429        // TODO: unify pk encoding with state table.
430        let pk_data_types = state_pk_indices
431            .iter()
432            .map(|i| state_all_data_types[*i].clone())
433            .collect();
434        let pk_serializer = OrderedRowSerde::new(
435            pk_data_types,
436            vec![OrderType::ascending(); state_pk_indices.len()],
437        );
438
439        let inequality_key_desc = inequality_key_idx.map(|idx| {
440            let serializer = OrderedRowSerde::new(
441                vec![state_all_data_types[idx].clone()],
442                vec![OrderType::ascending()],
443            );
444            InequalityKeyDesc { idx, serializer }
445        });
446
447        let join_table_id = state_table.table_id();
448        let state = TableInner {
449            pk_indices: state_pk_indices,
450            join_key_indices: state_join_key_indices,
451            order_key_indices: state_table.pk_indices().to_vec(),
452            table: state_table,
453        };
454
455        let need_degree_table = degree_state.is_some();
456
457        let metrics_info = MetricsInfo::new(
458            metrics.clone(),
459            join_table_id,
460            actor_id,
461            format!("hash join {}", side),
462        );
463
464        let cache = ManagedLruCache::unbounded_with_hasher_in(
465            watermark_sequence,
466            metrics_info,
467            PrecomputedBuildHasher,
468            alloc,
469        );
470
471        Self {
472            inner: cache,
473            join_key_data_types,
474            null_matched,
475            pk_serializer,
476            state,
477            degree_state,
478            need_degree_table,
479            pk_contained_in_jk,
480            inequality_key_desc,
481            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
482            _marker: std::marker::PhantomData,
483        }
484    }
485
486    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
487        self.state.table.init_epoch(epoch).await?;
488        if let Some(degree_state) = &mut self.degree_state {
489            degree_state.table.init_epoch(epoch).await?;
490        }
491        Ok(())
492    }
493}
494
495impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
496    pub async fn post_yield_barrier(
497        self,
498        vnode_bitmap: Option<Arc<Bitmap>>,
499    ) -> StreamExecutorResult<Option<bool>> {
500        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
501        if let Some(degree_state) = self.degree_state {
502            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
503        }
504        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
505        if cache_may_stale.unwrap_or(false) {
506            self.inner.clear();
507        }
508        Ok(cache_may_stale)
509    }
510}
511impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
512    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
513        // TODO: remove data in cache.
514        self.state.table.update_watermark(watermark.clone());
515        if let Some(degree_state) = &mut self.degree_state {
516            degree_state.table.update_watermark(watermark);
517        }
518    }
519
520    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
521    /// `update_state` after some operations to put the state back.
522    ///
523    /// If the state does not exist in the cache, fetch the remote storage and return. If it still
524    /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be
525    /// returned.
526    ///
527    /// Note: This will NOT remove anything from remote storage.
528    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
529        self.metrics.total_lookup_count += 1;
530        if self.inner.contains(key) {
531            tracing::trace!("hit cache for join key: {:?}", key);
532            // Do not update the LRU statistics here with `peek_mut` since we will put the state
533            // back.
534            let mut state = self.inner.peek_mut(key).expect("checked contains");
535            CacheResult::Hit(state.take())
536        } else {
537            tracing::trace!("miss cache for join key: {:?}", key);
538            CacheResult::Miss
539        }
540    }
541
542    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
543    /// `update_state` after some operations to put the state back.
544    ///
545    /// If the state does not exist in the cache, fetch the remote storage and return. If it still
546    /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be
547    /// returned.
548    ///
549    /// Note: This will NOT remove anything from remote storage.
550    pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType<E>> {
551        self.metrics.total_lookup_count += 1;
552        let state = if self.inner.contains(key) {
553            // Do not update the LRU statistics here with `peek_mut` since we will put the state
554            // back.
555            let mut state = self.inner.peek_mut(key).unwrap();
556            state.take()
557        } else {
558            self.metrics.lookup_miss_count += 1;
559            self.fetch_cached_state(key).await?.into()
560        };
561        Ok(state)
562    }
563
564    /// Fetch cache from the state store. Should only be called if the key does not exist in memory.
565    /// Will return a empty `JoinEntryState` even when state does not exist in remote.
566    async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState<E>> {
567        let key = key.deserialize(&self.join_key_data_types)?;
568
569        let mut entry_state: JoinEntryState<E> = JoinEntryState::default();
570
571        if self.need_degree_table {
572            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
573                &(Bound::Unbounded, Bound::Unbounded);
574            let table_iter_fut = self.state.table.iter_keyed_row_with_prefix(
575                &key,
576                sub_range,
577                PrefetchOptions::default(),
578            );
579            let degree_state = self.degree_state.as_ref().unwrap();
580            let degree_table_iter_fut = degree_state.table.iter_keyed_row_with_prefix(
581                &key,
582                sub_range,
583                PrefetchOptions::default(),
584            );
585
586            let (table_iter, degree_table_iter) =
587                try_join(table_iter_fut, degree_table_iter_fut).await?;
588
589            let mut pinned_table_iter = std::pin::pin!(table_iter);
590            let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);
591
592            // For better tolerating inconsistent stream, we have to first buffer all rows and
593            // degree rows, and check the number of them, then iterate on them.
594            let mut rows = vec![];
595            let mut degree_rows = vec![];
596            let mut inconsistency_happened = false;
597            loop {
598                let (row, degree_row) =
599                    join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
600                let (row, degree_row) = match (row, degree_row) {
601                    (None, None) => break,
602                    (None, Some(_)) => {
603                        inconsistency_happened = true;
604                        consistency_panic!(
605                            "mismatched row and degree table of join key: {:?}, degree table has more rows",
606                            &key
607                        );
608                        break;
609                    }
610                    (Some(_), None) => {
611                        inconsistency_happened = true;
612                        consistency_panic!(
613                            "mismatched row and degree table of join key: {:?}, input table has more rows",
614                            &key
615                        );
616                        break;
617                    }
618                    (Some(r), Some(d)) => (r, d),
619                };
620
621                let row = row?;
622                let degree_row = degree_row?;
623                rows.push(row);
624                degree_rows.push(degree_row);
625            }
626
627            if inconsistency_happened {
628                // Pk-based row-degree pairing.
629                assert_ne!(rows.len(), degree_rows.len());
630
631                let row_iter = stream::iter(rows.into_iter()).peekable();
632                let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
633                pin_mut!(row_iter);
634                pin_mut!(degree_row_iter);
635
636                loop {
637                    match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
638                        (None, _) | (_, None) => break,
639                        (Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
640                            Ordering::Greater => {
641                                degree_row_iter.next().await;
642                            }
643                            Ordering::Less => {
644                                row_iter.next().await;
645                            }
646                            Ordering::Equal => {
647                                let row =
648                                    row_iter.next().await.expect("we matched some(row) above");
649                                let degree_row = degree_row_iter
650                                    .next()
651                                    .await
652                                    .expect("we matched some(degree_row) above");
653                                let pk = row
654                                    .as_ref()
655                                    .project(&self.state.pk_indices)
656                                    .memcmp_serialize(&self.pk_serializer);
657                                let degree_i64 = degree_row
658                                    .datum_at(degree_row.len() - 1)
659                                    .expect("degree should not be NULL");
660                                let inequality_key = self
661                                    .inequality_key_desc
662                                    .as_ref()
663                                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
664                                entry_state
665                                    .insert(
666                                        pk,
667                                        E::encode(&JoinRow::new(
668                                            row.row(),
669                                            degree_i64.into_int64() as u64,
670                                        )),
671                                        inequality_key,
672                                    )
673                                    .with_context(|| self.state.error_context(row.row()))?;
674                            }
675                        },
676                    }
677                }
678            } else {
679                // 1 to 1 row-degree pairing.
680                // Actually it's possible that both the input data table and the degree table missed
681                // some equal number of rows, but let's ignore this case because it should be rare.
682
683                assert_eq!(rows.len(), degree_rows.len());
684
685                #[for_await]
686                for (row, degree_row) in
687                    stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
688                {
689                    let row: KeyedRow<_> = row;
690                    let degree_row: KeyedRow<_> = degree_row;
691
692                    let pk1 = row.key();
693                    let pk2 = degree_row.key();
694                    debug_assert_eq!(
695                        pk1, pk2,
696                        "mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
697                    );
698                    let pk = row
699                        .as_ref()
700                        .project(&self.state.pk_indices)
701                        .memcmp_serialize(&self.pk_serializer);
702                    let inequality_key = self
703                        .inequality_key_desc
704                        .as_ref()
705                        .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
706                    let degree_i64 = degree_row
707                        .datum_at(degree_row.len() - 1)
708                        .expect("degree should not be NULL");
709                    entry_state
710                        .insert(
711                            pk,
712                            E::encode(&JoinRow::new(row.row(), degree_i64.into_int64() as u64)),
713                            inequality_key,
714                        )
715                        .with_context(|| self.state.error_context(row.row()))?;
716                }
717            }
718        } else {
719            let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
720                &(Bound::Unbounded, Bound::Unbounded);
721            let table_iter = self
722                .state
723                .table
724                .iter_keyed_row_with_prefix(&key, sub_range, PrefetchOptions::default())
725                .await?;
726
727            #[for_await]
728            for entry in table_iter {
729                let row: KeyedRow<_> = entry?;
730                let pk = row
731                    .as_ref()
732                    .project(&self.state.pk_indices)
733                    .memcmp_serialize(&self.pk_serializer);
734                let inequality_key = self
735                    .inequality_key_desc
736                    .as_ref()
737                    .map(|desc| desc.serialize_inequal_key_from_row(row.row()));
738                entry_state
739                    .insert(pk, E::encode(&JoinRow::new(row.row(), 0)), inequality_key)
740                    .with_context(|| self.state.error_context(row.row()))?;
741            }
742        };
743
744        Ok(entry_state)
745    }
746
747    pub async fn flush(
748        &mut self,
749        epoch: EpochPair,
750    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
751        self.metrics.flush();
752        let state_post_commit = self.state.table.commit(epoch).await?;
753        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
754            Some(degree_state.table.commit(epoch).await?)
755        } else {
756            None
757        };
758        Ok(JoinHashMapPostCommit {
759            state: state_post_commit,
760            degree_state: degree_state_post_commit,
761            inner: &mut self.inner,
762        })
763    }
764
765    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
766        self.state.table.try_flush().await?;
767        if let Some(degree_state) = &mut self.degree_state {
768            degree_state.table.try_flush().await?;
769        }
770        Ok(())
771    }
772
773    pub fn insert_handle_degree(
774        &mut self,
775        key: &K,
776        value: JoinRow<impl Row>,
777    ) -> StreamExecutorResult<()> {
778        if self.need_degree_table {
779            self.insert(key, value)
780        } else {
781            self.insert_row(key, value.row)
782        }
783    }
784
785    /// Insert a join row
786    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
787        let pk = self.serialize_pk_from_row(&value.row);
788
789        let inequality_key = self
790            .inequality_key_desc
791            .as_ref()
792            .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
793
794        // TODO(yuhao): avoid this `contains`.
795        // https://github.com/risingwavelabs/risingwave/issues/9233
796        if self.inner.contains(key) {
797            // Update cache
798            let mut entry = self.inner.get_mut(key).expect("checked contains");
799            entry
800                .insert(pk, E::encode(&value), inequality_key)
801                .with_context(|| self.state.error_context(&value.row))?;
802        } else if self.pk_contained_in_jk {
803            // Refill cache when the join key exist in neither cache or storage.
804            self.metrics.insert_cache_miss_count += 1;
805            let mut entry: JoinEntryState<E> = JoinEntryState::default();
806            entry
807                .insert(pk, E::encode(&value), inequality_key)
808                .with_context(|| self.state.error_context(&value.row))?;
809            self.update_state(key, entry.into());
810        }
811
812        // Update the flush buffer.
813        if let Some(degree_state) = self.degree_state.as_mut() {
814            let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
815            self.state.table.insert(row);
816            degree_state.table.insert(degree);
817        } else {
818            self.state.table.insert(value.row);
819        }
820        Ok(())
821    }
822
823    /// Insert a row.
824    /// Used when the side does not need to update degree.
825    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
826        let join_row = JoinRow::new(&value, 0);
827        self.insert(key, join_row)?;
828        Ok(())
829    }
830
831    pub fn delete_handle_degree(
832        &mut self,
833        key: &K,
834        value: JoinRow<impl Row>,
835    ) -> StreamExecutorResult<()> {
836        if self.need_degree_table {
837            self.delete(key, value)
838        } else {
839            self.delete_row(key, value.row)
840        }
841    }
842
843    /// Delete a join row
844    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
845        if let Some(mut entry) = self.inner.get_mut(key) {
846            let pk = (&value.row)
847                .project(&self.state.pk_indices)
848                .memcmp_serialize(&self.pk_serializer);
849            let inequality_key = self
850                .inequality_key_desc
851                .as_ref()
852                .map(|desc| desc.serialize_inequal_key_from_row(&value.row));
853            entry
854                .remove(pk, inequality_key.as_ref())
855                .with_context(|| self.state.error_context(&value.row))?;
856        }
857
858        // If no cache maintained, only update the state table.
859        let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
860        self.state.table.delete(row);
861        let degree_state = self.degree_state.as_mut().expect("degree table missing");
862        degree_state.table.delete(degree);
863        Ok(())
864    }
865
866    /// Delete a row
867    /// Used when the side does not need to update degree.
868    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
869        if let Some(mut entry) = self.inner.get_mut(key) {
870            let pk = (&value)
871                .project(&self.state.pk_indices)
872                .memcmp_serialize(&self.pk_serializer);
873
874            let inequality_key = self
875                .inequality_key_desc
876                .as_ref()
877                .map(|desc| desc.serialize_inequal_key_from_row(&value));
878            entry
879                .remove(pk, inequality_key.as_ref())
880                .with_context(|| self.state.error_context(&value))?;
881        }
882
883        // If no cache maintained, only update the state table.
884        self.state.table.delete(value);
885        Ok(())
886    }
887
888    /// Update a [`JoinEntryState`] into the hash table.
889    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
890        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
891    }
892
    /// Evict the cache.
    ///
    /// Delegates to the `evict` method of the underlying cache (`self.inner`); which entries
    /// are dropped is decided there.
    pub fn evict(&mut self) {
        self.inner.evict();
    }
897
    /// Cached entry count for this hash table.
    ///
    /// This is the `len` of the underlying cache (`self.inner`).
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }
902
    /// Borrow the `null_matched` bitmap stored in this hash map.
    ///
    /// NOTE(review): presumably one bit per join key column telling whether NULLs compare as
    /// equal for that column — confirm against the constructor.
    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }
906
    /// Table id of the state table (`self.state.table`) of this hash map.
    ///
    /// The id of the degree table, if any, is not exposed here.
    pub fn table_id(&self) -> u32 {
        self.state.table.table_id()
    }
910
    /// Borrow the join key data types stored in this hash map.
    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }
914
915    /// Return true if the inequality key is null.
916    /// # Panics
917    /// Panics if the inequality key is not set.
918    pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
919        let desc = self
920            .inequality_key_desc
921            .as_ref()
922            .expect("inequality key desc missing");
923        row.datum_at(desc.idx).is_none()
924    }
925
926    /// Serialize the inequality key from a row.
927    /// # Panics
928    /// Panics if the inequality key is not set.
929    pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
930        self.inequality_key_desc
931            .as_ref()
932            .expect("inequality key desc missing")
933            .serialize_inequal_key_from_row(&row)
934    }
935
936    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
937        row.project(&self.state.pk_indices)
938            .memcmp_serialize(&self.pk_serializer)
939    }
940}
941
/// Value returned by `JoinHashMap::flush` once the tables of a join side were committed.
///
/// Marked `#[must_use]`, so the compiler warns when a caller drops it without using it.
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Post-commit handle of the state table.
    state: StateTablePostCommit<'a, S>,
    /// Post-commit handle of the degree table; `None` when the side has no degree table.
    degree_state: Option<StateTablePostCommit<'a, S>>,
    /// Mutable borrow of the in-memory cache of the hash map that was flushed.
    inner: &'a mut JoinHashMapInner<K, E>,
}
948
949use risingwave_common_estimate_size::KvSize;
950use thiserror::Error;
951
952use super::*;
953use crate::executor::prelude::{Stream, try_stream};
954
/// We manage an in-memory row set (`cached`) holding all entries that belong to one join key.
/// When evicted, `cached` does not hold any entries.
///
/// If a `JoinEntryState` exists for a join key, all records under this
/// join key are present in the cache.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The full copy of the state, keyed by the serialized primary key.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Index used for AS OF join. The key is the serialized inequality column value. The value
    /// is the set of primary keys in `cached` whose rows carry that inequality key.
    inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
    /// Running size estimate of the keys and values tracked by this entry; this is what
    /// `estimated_heap_size` reports.
    kv_heap_size: KvSize,
}
968
impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    /// Report the size tracked in `kv_heap_size`.
    ///
    /// Only the key/value sizes recorded by the insert and remove paths are included; the
    /// overhead of the containers themselves is not (see the TODO below).
    fn estimated_heap_size(&self) -> usize {
        // TODO: Add btreemap internal size.
        // https://github.com/risingwavelabs/risingwave/issues/9713
        self.kv_heap_size.size()
    }
}
976
/// Consistency errors reported by [`JoinEntryState`].
#[derive(Error, Debug)]
pub enum JoinEntryError {
    /// Returned by `JoinEntryState::insert` when the primary key is already cached.
    #[error("double inserting a join state entry")]
    Occupied,
    /// Returned by `JoinEntryState::remove` under strict consistency when the primary key is
    /// not cached.
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
    /// Produced by `JoinEntryState::get_by_indexed_pk` under strict consistency when the
    /// inequality index references a primary key that is not cached.
    #[error("retrieving a pk from the inequality index but it is not in the cache")]
    InequalIndex,
}
986
987impl<E: JoinEncoding> JoinEntryState<E> {
988    /// Insert into the cache.
989    pub fn insert(
990        &mut self,
991        key: PkType,
992        value: E::EncodedRow,
993        inequality_key: Option<InequalKeyType>,
994    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
995        let mut removed = false;
996        if !enable_strict_consistency() {
997            // strict consistency is off, let's remove existing (if any) first
998            if let Some(old_value) = self.cached.remove(&key) {
999                if let Some(inequality_key) = inequality_key.as_ref() {
1000                    self.remove_pk_from_inequality_index(&key, inequality_key);
1001                }
1002                self.kv_heap_size.sub(&key, &old_value);
1003                removed = true;
1004            }
1005        }
1006
1007        self.kv_heap_size.add(&key, &value);
1008
1009        if let Some(inequality_key) = inequality_key {
1010            self.insert_pk_to_inequality_index(key.clone(), inequality_key);
1011        }
1012        let ret = self.cached.try_insert(key.clone(), value);
1013
1014        if !enable_strict_consistency() {
1015            assert!(ret.is_ok(), "we have removed existing entry, if any");
1016            if removed {
1017                // if not silent, we should log the error
1018                consistency_error!(?key, "double inserting a join state entry");
1019            }
1020        }
1021
1022        ret.map_err(|_| JoinEntryError::Occupied)
1023    }
1024
1025    /// Delete from the cache.
1026    pub fn remove(
1027        &mut self,
1028        pk: PkType,
1029        inequality_key: Option<&InequalKeyType>,
1030    ) -> Result<(), JoinEntryError> {
1031        if let Some(value) = self.cached.remove(&pk) {
1032            self.kv_heap_size.sub(&pk, &value);
1033            if let Some(inequality_key) = inequality_key {
1034                self.remove_pk_from_inequality_index(&pk, inequality_key);
1035            }
1036            Ok(())
1037        } else if enable_strict_consistency() {
1038            Err(JoinEntryError::Remove)
1039        } else {
1040            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
1041            Ok(())
1042        }
1043    }
1044
1045    fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
1046        if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
1047            if pk_set.remove(pk).is_none() {
1048                if enable_strict_consistency() {
1049                    panic!("removing a pk that it not in the inequality index");
1050                } else {
1051                    consistency_error!(?pk, "removing a pk that it not in the inequality index");
1052                };
1053            } else {
1054                self.kv_heap_size.sub(pk, &());
1055            }
1056            if pk_set.is_empty() {
1057                self.inequality_index.remove(inequality_key);
1058            }
1059        }
1060    }
1061
1062    fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
1063        if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
1064            let pk_size = pk.estimated_size();
1065            if pk_set.try_insert(pk, ()).is_err() {
1066                if enable_strict_consistency() {
1067                    panic!("inserting a pk that it already in the inequality index");
1068                } else {
1069                    consistency_error!("inserting a pk that it already in the inequality index");
1070                };
1071            } else {
1072                self.kv_heap_size.add_size(pk_size);
1073            }
1074        } else {
1075            let mut pk_set = JoinRowSet::default();
1076            pk_set.try_insert(pk, ()).expect("pk set should be empty");
1077            self.inequality_index
1078                .try_insert(inequality_key, pk_set)
1079                .expect("pk set should be empty");
1080        }
1081    }
1082
1083    pub fn get(
1084        &self,
1085        pk: &PkType,
1086        data_types: &[DataType],
1087    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1088        self.cached
1089            .get(pk)
1090            .map(|encoded| encoded.decode(data_types))
1091    }
1092
1093    /// Note: the first item in the tuple is the mutable reference to the value in this entry, while
1094    /// the second item is the decoded value. To mutate the degree, one **must not** forget to apply
1095    /// the changes to the first item.
1096    ///
1097    /// WARNING: Should not change the heap size of `StateValueType` with the mutable reference.
1098    pub fn values_mut<'a>(
1099        &'a mut self,
1100        data_types: &'a [DataType],
1101    ) -> impl Iterator<
1102        Item = (
1103            &'a mut E::EncodedRow,
1104            StreamExecutorResult<JoinRow<E::DecodedRow>>,
1105        ),
1106    > + 'a {
1107        self.cached.values_mut().map(|encoded| {
1108            let decoded = encoded.decode(data_types);
1109            (encoded, decoded)
1110        })
1111    }
1112
1113    pub fn len(&self) -> usize {
1114        self.cached.len()
1115    }
1116
1117    /// Range scan the cache using the inequality index.
1118    pub fn range_by_inequality<'a, R>(
1119        &'a self,
1120        range: R,
1121        data_types: &'a [DataType],
1122    ) -> impl Iterator<Item = StreamExecutorResult<JoinRow<E::DecodedRow>>> + 'a
1123    where
1124        R: RangeBounds<InequalKeyType> + 'a,
1125    {
1126        self.inequality_index.range(range).flat_map(|(_, pk_set)| {
1127            pk_set
1128                .keys()
1129                .flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
1130        })
1131    }
1132
1133    /// Get the records whose inequality key upper bound satisfy the given bound.
1134    pub fn upper_bound_by_inequality<'a>(
1135        &'a self,
1136        bound: Bound<&InequalKeyType>,
1137        data_types: &'a [DataType],
1138    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1139        if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
1140            if let Some(pk) = pk_set.first_key_sorted() {
1141                self.get_by_indexed_pk(pk, data_types)
1142            } else {
1143                panic!("pk set for a index record must has at least one element");
1144            }
1145        } else {
1146            None
1147        }
1148    }
1149
1150    pub fn get_by_indexed_pk(
1151        &self,
1152        pk: &PkType,
1153        data_types: &[DataType],
1154    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>>
1155where {
1156        if let Some(value) = self.cached.get(pk) {
1157            Some(value.decode(data_types))
1158        } else if enable_strict_consistency() {
1159            Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
1160        } else {
1161            consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
1162            None
1163        }
1164    }
1165
1166    /// Get the records whose inequality key lower bound satisfy the given bound.
1167    pub fn lower_bound_by_inequality<'a>(
1168        &'a self,
1169        bound: Bound<&InequalKeyType>,
1170        data_types: &'a [DataType],
1171    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1172        if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
1173            if let Some(pk) = pk_set.first_key_sorted() {
1174                self.get_by_indexed_pk(pk, data_types)
1175            } else {
1176                panic!("pk set for a index record must has at least one element");
1177            }
1178        } else {
1179            None
1180        }
1181    }
1182
1183    pub fn get_first_by_inequality<'a>(
1184        &'a self,
1185        inequality_key: &InequalKeyType,
1186        data_types: &'a [DataType],
1187    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
1188        if let Some(pk_set) = self.inequality_index.get(inequality_key) {
1189            if let Some(pk) = pk_set.first_key_sorted() {
1190                self.get_by_indexed_pk(pk, data_types)
1191            } else {
1192                panic!("pk set for a index record must has at least one element");
1193            }
1194        } else {
1195            None
1196        }
1197    }
1198
1199    pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
1200        &self.inequality_index
1201    }
1202}
1203
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    /// Insert every row of `data_chunk` into `managed_state` with degree 0.
    ///
    /// The pk is the memcmp encoding of the `pk_indices` columns of the row. When
    /// `inequality_key_idx` is given, that column is encoded as the inequality key.
    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        inequality_key_idx: Option<usize>,
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
        let inequality_key_serializer = inequality_key_type
            .map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            // Project the pk columns straight from the row, like the production code does.
            // Pk is only a `i64` here, so encoding method does not matter.
            let pk = (&row)
                .project(pk_indices)
                .memcmp_serialize(&pk_serializer);
            let inequality_key = inequality_key_idx.map(|idx| {
                (&row)
                    .project(&[idx])
                    .memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
            });
            let join_row = JoinRow { row, degree: 0 };
            managed_state
                .insert(pk, E::encode(&join_row), inequality_key)
                .unwrap();
        }
    }

    /// Assert that the rows of `managed_state`, in `values_mut` iteration order, carry the
    /// values of `col1` and `col2` in their first two columns and have degree 0.
    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        // Expected in insertion order for the first three rows.
        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        // `Vec` in state
        insert_chunk::<MemoryEncoding>(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk1,
        );
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // `BtreeMap` in state
        // Expected in pk order once five rows are stored.
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            None,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
    }

    #[tokio::test]
    async fn test_managed_join_state_w_inequality_index() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];
        let inequality_key_idx = Some(1);
        let inequality_key_serializer =
            OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 5];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 5",
        );

        // `Vec` in state
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk1,
        );
        check(&mut managed_state, &col_types, &col1, &col2);
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
            .memcmp_serialize(&inequality_key_serializer);
        // Inequality key 5 is shared by pk 2 and pk 1; the smallest pk is returned.
        let row = managed_state
            .upper_bound_by_inequality(Bound::Included(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
        // Excluding 5 falls back to inequality key 4, which belongs to pk 3.
        let row = managed_state
            .upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
            .unwrap()
            .unwrap();
        assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));

        // `BtreeMap` in state
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [5, 5, 4, 4, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 4",
        );
        insert_chunk(
            &mut managed_state,
            &pk_indices,
            &col_types,
            inequality_key_idx,
            &data_chunk2,
        );
        check(&mut managed_state, &col_types, &col1, &col2);

        // 8 is the largest inequality key, so nothing lies strictly above it.
        let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
            .memcmp_serialize(&inequality_key_serializer);
        let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
        assert!(row.is_none());
    }
}