risingwave_stream/executor/join/
hash_join.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::{Bound, Deref, DerefMut};
16use std::sync::Arc;
17
18use anyhow::Context;
19use futures::StreamExt;
20use futures_async_stream::for_await;
21use join_row_set::JoinRowSet;
22use risingwave_common::bitmap::Bitmap;
23use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
24use risingwave_common::metrics::LabelGuardedIntCounter;
25use risingwave_common::row::{OwnedRow, Row, RowExt};
26use risingwave_common::types::{DataType, ScalarImpl};
27use risingwave_common::util::epoch::EpochPair;
28use risingwave_common::util::row_serde::OrderedRowSerde;
29use risingwave_common::util::sort_util::OrderType;
30use risingwave_common_estimate_size::EstimateSize;
31use risingwave_storage::StateStore;
32use risingwave_storage::store::PrefetchOptions;
33
34use super::row::{CachedJoinRow, DegreeType, build_degree_row};
35use crate::cache::ManagedLruCache;
36use crate::common::metrics::MetricsInfo;
37use crate::common::table::state_table::{StateTable, StateTablePostCommit};
38use crate::consistency::{consistency_error, enable_strict_consistency};
39use crate::executor::error::StreamExecutorResult;
40use crate::executor::join::row::JoinRow;
41use crate::executor::monitor::StreamingMetrics;
42use crate::executor::{JoinEncoding, StreamExecutorError};
43use crate::task::{ActorId, AtomicU64Ref, FragmentId};
44
/// Memcomparable encoding.
type PkType = Vec<u8>;
/// Per-join-key entry state as stored in the LRU cache. Boxed to keep the
/// in-map value small and cheap to move.
pub type HashValueType<E> = Box<JoinEntryState<E>>;
48
49impl<E: JoinEncoding> EstimateSize for Box<JoinEntryState<E>> {
50    fn estimated_heap_size(&self) -> usize {
51        self.as_ref().estimated_heap_size()
52    }
53}
54
/// The wrapper for [`JoinEntryState`] which should be `Some` most of the time in the hash table.
///
/// When the executor is operating on the specific entry of the map, it can hold the ownership of
/// the entry by taking the value out of the `Option`, instead of holding a mutable reference to the
/// map, which can make the compiler happy.
///
/// NOTE(review): after `take`, the wrapper stays `None` until it is overwritten —
/// presumably via `update_state` putting a fresh wrapper back; confirm with callers.
struct HashValueWrapper<E: JoinEncoding>(Option<HashValueType<E>>);
61
/// Result of a cache lookup for a join key, see [`JoinHashMap::take_state_opt`].
pub(crate) enum CacheResult<E: JoinEncoding> {
    NeverMatch,            // Will never match, will not be in cache at all.
    Miss,                  // Cache-miss
    Hit(HashValueType<E>), // Cache-hit
}
67
impl<E: JoinEncoding> EstimateSize for HashValueWrapper<E> {
    fn estimated_heap_size(&self) -> usize {
        // Delegates to the `Option`'s implementation over the boxed entry state.
        self.0.estimated_heap_size()
    }
}
73
74impl<E: JoinEncoding> HashValueWrapper<E> {
75    const MESSAGE: &'static str = "the state should always be `Some`";
76
77    /// Take the value out of the wrapper. Panic if the value is `None`.
78    pub fn take(&mut self) -> HashValueType<E> {
79        self.0.take().expect(Self::MESSAGE)
80    }
81}
82
impl<E: JoinEncoding> Deref for HashValueWrapper<E> {
    type Target = HashValueType<E>;

    fn deref(&self) -> &Self::Target {
        // Panics with `MESSAGE` if the state has been taken out and not put back.
        self.0.as_ref().expect(Self::MESSAGE)
    }
}

impl<E: JoinEncoding> DerefMut for HashValueWrapper<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Panics with `MESSAGE` if the state has been taken out and not put back.
        self.0.as_mut().expect(Self::MESSAGE)
    }
}
96
/// LRU cache from join key to the in-memory entry state under that key.
type JoinHashMapInner<K, E> = ManagedLruCache<K, HashValueWrapper<E>, PrecomputedBuildHasher>;
98
pub struct JoinHashMapMetrics {
    /// Basic information
    /// How many times a cache lookup missed (entry not found in the LRU cache).
    lookup_miss_count: usize,
    /// Total number of cache lookups, hit or miss.
    total_lookup_count: usize,
    /// How many times have we miss the cache when insert row
    insert_cache_miss_count: usize,

    // Metrics (the local counts above are flushed into these guarded counters)
    join_lookup_total_count_metric: LabelGuardedIntCounter,
    join_lookup_miss_count_metric: LabelGuardedIntCounter,
    join_insert_cache_miss_count_metrics: LabelGuardedIntCounter,
}
112
113impl JoinHashMapMetrics {
114    pub fn new(
115        metrics: &StreamingMetrics,
116        actor_id: ActorId,
117        fragment_id: FragmentId,
118        side: &'static str,
119        join_table_id: TableId,
120    ) -> Self {
121        let actor_id = actor_id.to_string();
122        let fragment_id = fragment_id.to_string();
123        let join_table_id = join_table_id.to_string();
124        let join_lookup_total_count_metric = metrics
125            .join_lookup_total_count
126            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
127        let join_lookup_miss_count_metric = metrics
128            .join_lookup_miss_count
129            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
130        let join_insert_cache_miss_count_metrics = metrics
131            .join_insert_cache_miss_count
132            .with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
133
134        Self {
135            lookup_miss_count: 0,
136            total_lookup_count: 0,
137            insert_cache_miss_count: 0,
138            join_lookup_total_count_metric,
139            join_lookup_miss_count_metric,
140            join_insert_cache_miss_count_metrics,
141        }
142    }
143
144    pub fn inc_lookup(&mut self) {
145        self.total_lookup_count += 1;
146    }
147
148    pub fn inc_lookup_miss(&mut self) {
149        self.lookup_miss_count += 1;
150    }
151
152    pub fn inc_insert_cache_miss(&mut self) {
153        self.insert_cache_miss_count += 1;
154    }
155
156    pub fn flush(&mut self) {
157        self.join_lookup_total_count_metric
158            .inc_by(self.total_lookup_count as u64);
159        self.join_lookup_miss_count_metric
160            .inc_by(self.lookup_miss_count as u64);
161        self.join_insert_cache_miss_count_metrics
162            .inc_by(self.insert_cache_miss_count as u64);
163        self.total_lookup_count = 0;
164        self.lookup_miss_count = 0;
165        self.insert_cache_miss_count = 0;
166    }
167}
168
pub struct JoinHashMap<K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Store the join states.
    inner: JoinHashMapInner<K, E>,
    /// Data types of the join key columns
    join_key_data_types: Vec<DataType>,
    /// Null safe bitmap for each join pair
    null_matched: K::Bitmap,
    /// The memcomparable serializer of primary key.
    pk_serializer: OrderedRowSerde,
    /// State table. Contains the data from upstream.
    state: TableInner<S>,
    /// Degree table.
    ///
    /// The degree is generated from the hash join executor.
    /// Each row in `state` has a corresponding degree in `degree state`.
    /// A degree value `d` in for a row means the row has `d` matched row in the other join side.
    ///
    /// It will only be used when needed in a side.
    ///
    /// - Full Outer: both side
    /// - Left Outer/Semi/Anti: left side
    /// - Right Outer/Semi/Anti: right side
    /// - Inner: neither side.
    ///
    /// Should be set to `None` if `need_degree_table` was set to `false`.
    ///
    /// The degree of each row will tell us if we need to emit `NULL` for the row.
    /// For instance, given `lhs LEFT JOIN rhs`,
    /// If the degree of a row in `lhs` is 0, it means the row does not have a match in `rhs`.
    /// If the degree of a row in `lhs` is 2, it means the row has two matches in `rhs`.
    /// Now, when emitting the result of the join, we need to emit `NULL` for the row in `lhs` if
    /// the degree is 0.
    ///
    /// Why don't just use a boolean value instead of a degree count?
    /// Consider the case where we delete a matched record from `rhs`.
    /// Since we can delete a record,
    /// there must have been a record in `rhs` that matched the record in `lhs`.
    /// So this value is `true`.
    /// But we don't know how many records are matched after removing this record,
    /// since we only stored a boolean value rather than the count.
    /// Hence we need to store the count of matched records.
    degree_state: Option<TableInner<S>>,
    // TODO(kwannoel): Make this `const` instead.
    /// If degree table is needed. Always equals `degree_state.is_some()`.
    need_degree_table: bool,
    /// Pk is part of the join key.
    pk_contained_in_jk: bool,
    /// Metrics of the hash map
    metrics: JoinHashMapMetrics,
    /// Ties the otherwise-unused join-row encoding parameter `E` to the struct.
    _marker: std::marker::PhantomData<E>,
}
220
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Borrow the order key indices and the degree state mutably at the same time.
    pub(crate) fn get_degree_state_mut_ref(&mut self) -> (&[usize], &mut Option<TableInner<S>>) {
        (&self.state.order_key_indices, &mut self.degree_state)
    }

    /// NOTE(kwannoel): This allows us to concurrently stream records from the `state_table`,
    /// and update the degree table, without using `unsafe` code.
    ///
    /// This is because we obtain separate references to separate parts of the `JoinHashMap`,
    /// instead of reusing the same reference to `JoinHashMap` for concurrent read access to `state_table`,
    /// and write access to the degree table.
    pub(crate) async fn fetch_matched_rows_and_get_degree_table_ref<'a>(
        &'a mut self,
        key: &'a K,
    ) -> StreamExecutorResult<(
        impl Stream<Item = StreamExecutorResult<(PkType, JoinRow<OwnedRow>)>> + 'a,
        &'a [usize],
        &'a mut Option<TableInner<S>>,
    )> {
        let degree_state = &mut self.degree_state;
        // Split borrows of disjoint fields: the stream holds `state.table`, while the
        // caller receives an independent mutable borrow of the degree state.
        let (order_key_indices, pk_indices, state_table) = (
            &self.state.order_key_indices,
            &self.state.pk_indices,
            &mut self.state.table,
        );
        // Materialize all degrees upfront (see `fetch_degrees` for the rationale).
        let degrees = if let Some(degree_state) = degree_state {
            Some(fetch_degrees(key, &self.join_key_data_types, &degree_state.table).await?)
        } else {
            None
        };
        let stream = into_stream(
            &self.join_key_data_types,
            pk_indices,
            &self.pk_serializer,
            state_table,
            key,
            degrees,
        );
        Ok((stream, order_key_indices, &mut self.degree_state))
    }
}
262
/// Stream all rows under `key` from the state table as `(memcmp-serialized pk, JoinRow)`.
///
/// `degrees[i]` is attached to the `i`-th yielded row; this pairs correctly only if
/// `fetch_degrees` iterated its table in the same order as this scan —
/// NOTE(review): confirm state and degree tables share the same prefix ordering.
#[try_stream(ok = (PkType, JoinRow<OwnedRow>), error = StreamExecutorError)]
pub(crate) async fn into_stream<'a, K: HashKey, S: StateStore>(
    join_key_data_types: &'a [DataType],
    pk_indices: &'a [usize],
    pk_serializer: &'a OrderedRowSerde,
    state_table: &'a StateTable<S>,
    key: &'a K,
    degrees: Option<Vec<DegreeType>>,
) {
    // Unbounded sub-range: scan the whole prefix under the (deserialized) join key.
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let decoded_key = key.deserialize(join_key_data_types)?;
    let table_iter = state_table
        .iter_with_prefix_respecting_watermark(&decoded_key, sub_range, PrefetchOptions::default())
        .await?;

    #[for_await]
    for (i, entry) in table_iter.enumerate() {
        let encoded_row = entry?;
        let encoded_pk = encoded_row
            .as_ref()
            .project(pk_indices)
            .memcmp_serialize(pk_serializer);
        // Degree defaults to 0 when this side does not maintain a degree table.
        let join_row = JoinRow::new(encoded_row, degrees.as_ref().map_or(0, |d| d[i]));
        yield (encoded_pk, join_row);
    }
}
289
/// We use this to fetch ALL degrees into memory.
/// We use this instead of a streaming interface.
/// It is necessary because we must update the `degree_state_table` concurrently.
/// If we obtain the degrees in a stream,
/// we will need to hold an immutable reference to the state table for the entire lifetime,
/// preventing us from concurrently updating the state table.
///
/// The cost of fetching all degrees upfront is acceptable. We currently already do so
/// in `fetch_cached_state`.
/// The memory use should be limited since we only store a u64.
///
/// Let's say we have amplification of 1B, we will have 1B * 8 bytes ~= 8GB
///
/// We can also have further optimization, to permit breaking the streaming update,
/// to flush the in-memory degrees, if this is proven to have high memory consumption.
///
/// TODO(kwannoel): Perhaps we can cache these separately from matched rows too.
/// Because matched rows may occupy a larger capacity.
///
/// Argument for this:
/// We only hit this when cache miss. When cache miss, we will have this as one off cost.
/// Keeping this cached separately from matched rows is beneficial.
/// Then we can evict matched rows, without touching the degrees.
async fn fetch_degrees<K: HashKey, S: StateStore>(
    key: &K,
    join_key_data_types: &[DataType],
    degree_state_table: &StateTable<S>,
) -> StreamExecutorResult<Vec<DegreeType>> {
    let key = key.deserialize(join_key_data_types)?;
    let mut degrees = vec![];
    // Unbounded sub-range: scan the whole prefix under the join key.
    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
    let table_iter = degree_state_table
        .iter_with_prefix_respecting_watermark(key, sub_range, PrefetchOptions::default())
        .await?;
    let degree_col_idx = degree_col_idx_in_row(degree_state_table);
    #[for_await]
    for entry in table_iter {
        let degree_row = entry?;
        debug_assert!(
            degree_row.len() > degree_col_idx,
            "degree row should have at least pk_len + 1 columns"
        );
        let degree_i64 = degree_row
            .datum_at(degree_col_idx)
            .expect("degree should not be NULL");
        // Degrees are persisted as i64 but handled as u64 (`DegreeType`) in memory.
        degrees.push(degree_i64.into_int64() as u64);
    }
    Ok(degrees)
}
339
340fn degree_col_idx_in_row<S: StateStore>(degree_state_table: &StateTable<S>) -> usize {
341    // Degree column is at index pk_len in the full schema: [pk..., _degree, inequality?].
342    let degree_col_idx = degree_state_table.pk_indices().len();
343    match degree_state_table.value_indices() {
344        Some(value_indices) => value_indices
345            .iter()
346            .position(|idx| *idx == degree_col_idx)
347            .expect("degree column should be included in value indices"),
348        None => degree_col_idx,
349    }
350}
351
352// NOTE(kwannoel): This is not really specific to `TableInner`.
353// A degree table is `TableInner`, a `TableInner` might not be a degree table.
354// Hence we don't specify it in its impl block.
355pub(crate) fn update_degree<S: StateStore, const INCREMENT: bool>(
356    order_key_indices: &[usize],
357    degree_state: &mut TableInner<S>,
358    matched_row: &mut JoinRow<impl Row>,
359) {
360    let inequality_idx = degree_state.degree_inequality_idx;
361    let old_degree_row = build_degree_row(
362        order_key_indices,
363        matched_row.degree,
364        inequality_idx,
365        &matched_row.row,
366    );
367    if INCREMENT {
368        matched_row.degree += 1;
369    } else {
370        // DECREMENT
371        matched_row.degree -= 1;
372    }
373    let new_degree_row = build_degree_row(
374        order_key_indices,
375        matched_row.degree,
376        inequality_idx,
377        &matched_row.row,
378    );
379    degree_state.table.update(old_degree_row, new_degree_row);
380}
381
/// A state table together with the index metadata needed to interpret its rows.
pub struct TableInner<S: StateStore> {
    /// Indices of the (cache) pk in a state row
    pub(crate) pk_indices: Vec<usize>,
    /// Indices of the join key in a state row
    join_key_indices: Vec<usize>,
    /// The order key of the join side has the following format:
    /// | `join_key` ... | pk ... |
    /// Where `join_key` contains all the columns not in the pk.
    /// It should be a superset of the pk.
    order_key_indices: Vec<usize>,
    /// Optional: index of inequality column in the input row for degree table.
    /// Used for inequality-based watermark cleaning of degree tables.
    /// When present, the degree table schema is: [pk..., _degree, `inequality_val`].
    pub(crate) degree_inequality_idx: Option<usize>,
    /// The underlying state table holding the rows.
    pub(crate) table: StateTable<S>,
}
398
399impl<S: StateStore> TableInner<S> {
400    pub fn new(
401        pk_indices: Vec<usize>,
402        join_key_indices: Vec<usize>,
403        table: StateTable<S>,
404        degree_inequality_idx: Option<usize>,
405    ) -> Self {
406        let order_key_indices = table.pk_indices().to_vec();
407        Self {
408            pk_indices,
409            join_key_indices,
410            order_key_indices,
411            degree_inequality_idx,
412            table,
413        }
414    }
415
416    fn error_context(&self, row: &impl Row) -> String {
417        let pk = row.project(&self.pk_indices);
418        let jk = row.project(&self.join_key_indices);
419        format!(
420            "join key: {}, pk: {}, row: {}, state_table_id: {}",
421            jk.display(),
422            pk.display(),
423            row.display(),
424            self.table.table_id()
425        )
426    }
427}
428
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
    /// Create a [`JoinHashMap`] with the given LRU capacity.
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        watermark_sequence: AtomicU64Ref,
        join_key_data_types: Vec<DataType>,
        state_join_key_indices: Vec<usize>,
        state_all_data_types: Vec<DataType>,
        state_table: StateTable<S>,
        state_pk_indices: Vec<usize>,
        degree_state: Option<TableInner<S>>,
        null_matched: K::Bitmap,
        pk_contained_in_jk: bool,
        metrics: Arc<StreamingMetrics>,
        actor_id: ActorId,
        fragment_id: FragmentId,
        side: &'static str,
    ) -> Self {
        // TODO: unify pk encoding with state table.
        let pk_data_types = state_pk_indices
            .iter()
            .map(|i| state_all_data_types[*i].clone())
            .collect();
        // All pk columns are memcmp-serialized in ascending order.
        let pk_serializer = OrderedRowSerde::new(
            pk_data_types,
            vec![OrderType::ascending(); state_pk_indices.len()],
        );

        let join_table_id = state_table.table_id();
        let state = TableInner {
            pk_indices: state_pk_indices,
            join_key_indices: state_join_key_indices,
            order_key_indices: state_table.pk_indices().to_vec(),
            // The main state table carries no degree inequality column.
            degree_inequality_idx: None,
            table: state_table,
        };

        // `need_degree_table` mirrors whether a degree table was supplied.
        let need_degree_table = degree_state.is_some();

        let metrics_info = MetricsInfo::new(
            metrics.clone(),
            join_table_id,
            actor_id,
            format!("hash join {}", side),
        );

        // Capacity is governed by the global LRU watermark rather than a fixed bound.
        let cache = ManagedLruCache::unbounded_with_hasher(
            watermark_sequence,
            metrics_info,
            PrecomputedBuildHasher,
        );

        Self {
            inner: cache,
            join_key_data_types,
            null_matched,
            pk_serializer,
            state,
            degree_state,
            need_degree_table,
            pk_contained_in_jk,
            metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
            _marker: std::marker::PhantomData,
        }
    }

    /// Initialize both state tables (if the degree table exists) with the first epoch.
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state.table.init_epoch(epoch).await?;
        if let Some(degree_state) = &mut self.degree_state {
            degree_state.table.init_epoch(epoch).await?;
        }
        Ok(())
    }
}
503
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMapPostCommit<'_, K, S, E> {
    /// Finish barrier handling: apply a possibly-updated vnode bitmap to both tables
    /// and clear the LRU cache if the state table reports the cache may be stale.
    ///
    /// Returns the state table's `cache_may_stale` flag, if any.
    pub async fn post_yield_barrier(
        self,
        vnode_bitmap: Option<Arc<Bitmap>>,
    ) -> StreamExecutorResult<Option<bool>> {
        let cache_may_stale = self.state.post_yield_barrier(vnode_bitmap.clone()).await?;
        if let Some(degree_state) = self.degree_state {
            // The degree table follows the same vnode change; its result is unused.
            let _ = degree_state.post_yield_barrier(vnode_bitmap).await?;
        }
        let cache_may_stale = cache_may_stale.map(|(_, cache_may_stale)| cache_may_stale);
        if cache_may_stale.unwrap_or(false) {
            // Cached entries can no longer be trusted after the change; drop them all.
            self.inner.clear();
        }
        Ok(cache_may_stale)
    }
}
520impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinHashMap<K, S, E> {
521    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
522        // TODO: remove data in cache.
523        self.state.table.update_watermark(watermark.clone());
524        if let Some(degree_state) = &mut self.degree_state {
525            degree_state.table.update_watermark(watermark);
526        }
527    }
528
529    /// Take the state for the given `key` out of the hash table and return it. One **MUST** call
530    /// `update_state` after some operations to put the state back.
531    ///
532    /// If the state does not exist in the cache, fetch the remote storage and return. If it still
533    /// does not exist in the remote storage, a [`JoinEntryState`] with empty cache will be
534    /// returned.
535    ///
536    /// Note: This will NOT remove anything from remote storage.
537    pub fn take_state_opt(&mut self, key: &K) -> CacheResult<E> {
538        self.metrics.total_lookup_count += 1;
539        if self.inner.contains(key) {
540            tracing::trace!("hit cache for join key: {:?}", key);
541            // Do not update the LRU statistics here with `peek_mut` since we will put the state
542            // back.
543            let mut state = self.inner.peek_mut(key).expect("checked contains");
544            CacheResult::Hit(state.take())
545        } else {
546            self.metrics.lookup_miss_count += 1;
547            tracing::trace!("miss cache for join key: {:?}", key);
548            CacheResult::Miss
549        }
550    }
551
552    pub async fn flush(
553        &mut self,
554        epoch: EpochPair,
555    ) -> StreamExecutorResult<JoinHashMapPostCommit<'_, K, S, E>> {
556        self.metrics.flush();
557        let state_post_commit = self.state.table.commit(epoch).await?;
558        let degree_state_post_commit = if let Some(degree_state) = &mut self.degree_state {
559            Some(degree_state.table.commit(epoch).await?)
560        } else {
561            None
562        };
563        Ok(JoinHashMapPostCommit {
564            state: state_post_commit,
565            degree_state: degree_state_post_commit,
566            inner: &mut self.inner,
567        })
568    }
569
570    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
571        self.state.table.try_flush().await?;
572        if let Some(degree_state) = &mut self.degree_state {
573            degree_state.table.try_flush().await?;
574        }
575        Ok(())
576    }
577
578    pub fn insert_handle_degree(
579        &mut self,
580        key: &K,
581        value: JoinRow<impl Row>,
582    ) -> StreamExecutorResult<()> {
583        if self.need_degree_table {
584            self.insert(key, value)
585        } else {
586            self.insert_row(key, value.row)
587        }
588    }
589
590    /// Insert a join row
591    pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
592        let pk = self.serialize_pk_from_row(&value.row);
593
594        // TODO(yuhao): avoid this `contains`.
595        // https://github.com/risingwavelabs/risingwave/issues/9233
596        if self.inner.contains(key) {
597            // Update cache
598            let mut entry = self.inner.get_mut(key).expect("checked contains");
599            entry
600                .insert(pk, E::encode(&value))
601                .with_context(|| self.state.error_context(&value.row))?;
602        } else if self.pk_contained_in_jk {
603            // Refill cache when the join key exist in neither cache or storage.
604            self.metrics.insert_cache_miss_count += 1;
605            let mut entry: JoinEntryState<E> = JoinEntryState::default();
606            entry
607                .insert(pk, E::encode(&value))
608                .with_context(|| self.state.error_context(&value.row))?;
609            self.update_state(key, entry.into());
610        }
611
612        // Update the flush buffer.
613        if let Some(degree_state) = self.degree_state.as_mut() {
614            let (row, degree) = value.to_table_rows(
615                &self.state.order_key_indices,
616                degree_state.degree_inequality_idx,
617            );
618            self.state.table.insert(row);
619            degree_state.table.insert(degree);
620        } else {
621            self.state.table.insert(value.row);
622        }
623        Ok(())
624    }
625
626    /// Insert a row.
627    /// Used when the side does not need to update degree.
628    pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
629        let join_row = JoinRow::new(&value, 0);
630        self.insert(key, join_row)?;
631        Ok(())
632    }
633
634    pub fn delete_row_in_mem(&mut self, key: &K, value: &impl Row) -> StreamExecutorResult<()> {
635        if let Some(mut entry) = self.inner.get_mut(key) {
636            let pk = (&value)
637                .project(&self.state.pk_indices)
638                .memcmp_serialize(&self.pk_serializer);
639            entry
640                .remove(pk)
641                .with_context(|| self.state.error_context(&value))?;
642        }
643        Ok(())
644    }
645
646    pub fn delete_handle_degree(
647        &mut self,
648        key: &K,
649        value: JoinRow<impl Row>,
650    ) -> StreamExecutorResult<()> {
651        if self.need_degree_table {
652            self.delete(key, value)
653        } else {
654            self.delete_row(key, value.row)
655        }
656    }
657
    /// Delete a join row
    pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
        self.delete_row_in_mem(key, &value.row)?;

        // Delete from both the state table and the degree table. This method must only
        // be called when a degree table is maintained (`need_degree_table`), hence the
        // `expect` below; use `delete_row` otherwise.
        let degree_state = self.degree_state.as_mut().expect("degree table missing");
        let (row, degree) = value.to_table_rows(
            &self.state.order_key_indices,
            degree_state.degree_inequality_idx,
        );
        self.state.table.delete(row);
        degree_state.table.delete(degree);
        Ok(())
    }
672
673    /// Delete a row
674    /// Used when the side does not need to update degree.
675    pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
676        self.delete_row_in_mem(key, &value)?;
677
678        // If no cache maintained, only update the state table.
679        self.state.table.delete(value);
680        Ok(())
681    }
682
683    /// Update a [`JoinEntryState`] into the hash table.
684    pub fn update_state(&mut self, key: &K, state: HashValueType<E>) {
685        self.inner.put(key.clone(), HashValueWrapper(Some(state)));
686    }
687
    /// Evict the cache.
    pub fn evict(&mut self) {
        self.inner.evict();
    }

    /// Cached entry count for this hash table.
    pub fn entry_count(&self) -> usize {
        self.inner.len()
    }

    /// Null-safe flags for the join key columns.
    pub fn null_matched(&self) -> &K::Bitmap {
        &self.null_matched
    }

    /// Id of the underlying state table.
    pub fn table_id(&self) -> TableId {
        self.state.table.table_id()
    }

    /// Data types of the join key columns.
    pub fn join_key_data_types(&self) -> &[DataType] {
        &self.join_key_data_types
    }

    /// Memcomparable-serialize the pk columns of `row` into a [`PkType`].
    pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
        row.project(&self.state.pk_indices)
            .memcmp_serialize(&self.pk_serializer)
    }
714}
715
/// Post-commit handles for both tables, returned by [`JoinHashMap::flush`].
/// Marked `#[must_use]`: dropping it without calling `post_yield_barrier`
/// would skip the required post-barrier processing.
#[must_use]
pub struct JoinHashMapPostCommit<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    state: StateTablePostCommit<'a, S>,
    degree_state: Option<StateTablePostCommit<'a, S>>,
    inner: &'a mut JoinHashMapInner<K, E>,
}
722
723use risingwave_common::catalog::TableId;
724use risingwave_common_estimate_size::KvSize;
725use thiserror::Error;
726
727use super::*;
728use crate::executor::prelude::{Stream, try_stream};
729
/// We manage a `HashMap` in memory for all entries belonging to a join key.
/// When evicted, `cached` does not hold any entries.
///
/// If a `JoinEntryState` exists for a join key, then all records under this
/// join key will be present in the cache.
#[derive(Default)]
pub struct JoinEntryState<E: JoinEncoding> {
    /// The full copy of the state.
    cached: JoinRowSet<PkType, E::EncodedRow>,
    /// Tracked heap size of the cached keys and values, for LRU accounting.
    kv_heap_size: KvSize,
}
741
impl<E: JoinEncoding> EstimateSize for JoinEntryState<E> {
    fn estimated_heap_size(&self) -> usize {
        // TODO: Add btreemap internal size.
        // https://github.com/risingwavelabs/risingwave/issues/9713
        // Only the tracked key/value bytes are counted, not container overhead.
        self.kv_heap_size.size()
    }
}
749
/// Inconsistencies detected while mutating a [`JoinEntryState`]. Surfaced as hard
/// errors only when strict consistency is enabled (see `insert`/`remove`).
#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    Occupied,
    #[error("removing a join state entry but it is not in the cache")]
    Remove,
}
757
758impl<E: JoinEncoding> JoinEntryState<E> {
    /// Insert into the cache.
    ///
    /// Returns [`JoinEntryError::Occupied`] if `key` is already present — only possible
    /// under strict consistency; otherwise any existing entry is removed first and the
    /// inconsistency is merely logged.
    pub fn insert(
        &mut self,
        key: PkType,
        value: E::EncodedRow,
    ) -> Result<&mut E::EncodedRow, JoinEntryError> {
        let mut removed = false;
        if !enable_strict_consistency() {
            // strict consistency is off, let's remove existing (if any) first
            if let Some(old_value) = self.cached.remove(&key) {
                self.kv_heap_size.sub(&key, &old_value);
                removed = true;
            }
        }

        // NOTE(review): the size is added before `try_insert`; if the insert fails with
        // `Occupied` (strict mode only), the accounting is not rolled back — confirm
        // this is acceptable given the error aborts the executor anyway.
        self.kv_heap_size.add(&key, &value);

        let ret = self.cached.try_insert(key.clone(), value);

        if !enable_strict_consistency() {
            assert!(ret.is_ok(), "we have removed existing entry, if any");
            if removed {
                // if not silent, we should log the error
                consistency_error!(?key, "double inserting a join state entry");
            }
        }

        ret.map_err(|_| JoinEntryError::Occupied)
    }
788
789    /// Delete from the cache.
790    pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> {
791        if let Some(value) = self.cached.remove(&pk) {
792            self.kv_heap_size.sub(&pk, &value);
793            Ok(())
794        } else if enable_strict_consistency() {
795            Err(JoinEntryError::Remove)
796        } else {
797            consistency_error!(?pk, "removing a join state entry but it's not in the cache");
798            Ok(())
799        }
800    }
801
802    pub fn get(
803        &self,
804        pk: &PkType,
805        data_types: &[DataType],
806    ) -> Option<StreamExecutorResult<JoinRow<E::DecodedRow>>> {
807        self.cached
808            .get(pk)
809            .map(|encoded| encoded.decode(data_types))
810    }
811
    /// Note: the first item in the tuple is the mutable reference to the value in this entry, while
    /// the second item is the decoded value. To mutate the degree, one **must not** forget to apply
    /// the changes to the first item.
    ///
    /// WARNING: Should not change the heap size of `StateValueType` with the mutable reference.
    pub fn values_mut<'a>(
        &'a mut self,
        data_types: &'a [DataType],
    ) -> impl Iterator<
        Item = (
            &'a mut E::EncodedRow,
            StreamExecutorResult<JoinRow<E::DecodedRow>>,
        ),
    > + 'a {
        // Each item decodes lazily as the iterator advances.
        self.cached.values_mut().map(|encoded| {
            let decoded = encoded.decode(data_types);
            (encoded, decoded)
        })
    }
831
    /// Number of rows cached under this join key.
    pub fn len(&self) -> usize {
        self.cached.len()
    }
835}
836
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::*;
    use risingwave_common::types::ScalarRefImpl;
    use risingwave_common::util::iter_util::ZipEqDebug;

    use super::*;
    use crate::executor::MemoryEncoding;

    /// Encode every row of `data_chunk` and insert it into `managed_state`,
    /// keyed by the memcomparable-serialized pk columns.
    fn insert_chunk<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        pk_indices: &[usize],
        col_types: &[DataType],
        data_chunk: &DataChunk,
    ) {
        let pk_col_type = pk_indices
            .iter()
            .map(|idx| col_types[*idx].clone())
            .collect_vec();
        let pk_serializer =
            OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
        for row_ref in data_chunk.rows() {
            let row: OwnedRow = row_ref.into_owned_row();
            // NOTE(review): this projects the already-extracted pk row by
            // `0..row.len() - 1`, which coincides with `pk_indices` for the
            // 2-column fixtures below but looks fragile for other shapes — confirm.
            let value_indices = (0..row.len() - 1).collect_vec();
            let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
            // Pk is only a `i64` here, so encoding method does not matter.
            let pk = OwnedRow::new(pk)
                .project(&value_indices)
                .memcmp_serialize(&pk_serializer);
            let join_row = JoinRow { row, degree: 0 };
            managed_state.insert(pk, E::encode(&join_row)).unwrap();
        }
    }

    /// Assert the state yields exactly the rows `(col1[i], col2[i])`, in the
    /// map's iteration order, all with degree 0.
    fn check<E: JoinEncoding>(
        managed_state: &mut JoinEntryState<E>,
        col_types: &[DataType],
        col1: &[i64],
        col2: &[i64],
    ) {
        for ((_, matched_row), (d1, d2)) in managed_state
            .values_mut(col_types)
            .zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
        {
            let matched_row = matched_row.unwrap();
            assert_eq!(matched_row.row.datum_at(0), Some(ScalarRefImpl::Int64(*d1)));
            assert_eq!(matched_row.row.datum_at(1), Some(ScalarRefImpl::Int64(*d2)));
            assert_eq!(matched_row.degree, 0);
        }
    }

    #[tokio::test]
    async fn test_managed_join_state() {
        let mut managed_state: JoinEntryState<MemoryEncoding> = JoinEntryState::default();
        let col_types = vec![DataType::Int64, DataType::Int64];
        let pk_indices = [0];

        let col1 = [3, 2, 1];
        let col2 = [4, 5, 6];
        let data_chunk1 = DataChunk::from_pretty(
            "I I
             3 4
             2 5
             1 6",
        );

        // `Vec` in state
        insert_chunk::<MemoryEncoding>(&mut managed_state, &pk_indices, &col_types, &data_chunk1);
        check::<MemoryEncoding>(&mut managed_state, &col_types, &col1, &col2);

        // `BtreeMap` in state
        // After the second insert the expected columns are the union of both
        // chunks, ordered by pk (1..=5).
        let col1 = [1, 2, 3, 4, 5];
        let col2 = [6, 5, 4, 9, 8];
        let data_chunk2 = DataChunk::from_pretty(
            "I I
             5 8
             4 9",
        );
        insert_chunk(&mut managed_state, &pk_indices, &col_types, &data_chunk2);
        check(&mut managed_state, &col_types, &col1, &col2);
    }
}