use std::alloc::Global;
use std::cmp::Ordering;
use std::ops::{Bound, Deref, DerefMut, RangeBounds};
use std::sync::Arc;
use anyhow::{anyhow, Context};
use futures::future::{join, try_join};
use futures::{pin_mut, stream, StreamExt};
use futures_async_stream::for_await;
use join_row_set::JoinRowSet;
use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::hash::{HashKey, PrecomputedBuildHasher};
use risingwave_common::metrics::LabelGuardedIntCounter;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::row_serde::OrderedRowSerde;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::StateStore;
use thiserror_ext::AsReport;
use super::row::{DegreeType, EncodedJoinRow};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::StateTable;
use crate::consistency::{consistency_error, consistency_panic, enable_strict_consistency};
use crate::executor::error::StreamExecutorResult;
use crate::executor::join::row::JoinRow;
use crate::executor::monitor::StreamingMetrics;
use crate::task::{ActorId, AtomicU64Ref, FragmentId};
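// Type aliases for per-join-key state:
// - `PkType` / `InequalKeyType`: memcomparable-encoded primary key / inequality key.
// - `StateValueType`: one cached row in its encoded form.
// - `HashValueType`: a boxed `JoinEntryState` holding all rows of one join key.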
type PkType = Vec<u8>;
type InequalKeyType = Vec<u8>;
pub type StateValueType = EncodedJoinRow;
pub type HashValueType = Box<JoinEntryState>;
impl EstimateSize for HashValueType {
fn estimated_heap_size(&self) -> usize {
self.as_ref().estimated_heap_size()
}
}
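/// Wraps a cache entry in an `Option` so that the state can be temporarily moved out with
/// `take` while the key stays in the LRU cache. Dereferencing an emptied wrapper panics with
/// `MESSAGE`, so the state must be put back before the next access.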
struct HashValueWrapper(Option<HashValueType>);
impl EstimateSize for HashValueWrapper {
fn estimated_heap_size(&self) -> usize {
self.0.estimated_heap_size()
}
}
impl HashValueWrapper {
const MESSAGE: &'static str = "the state should always be `Some`";
pub fn take(&mut self) -> HashValueType {
self.0.take().expect(Self::MESSAGE)
}
}
impl Deref for HashValueWrapper {
type Target = HashValueType;
fn deref(&self) -> &Self::Target {
self.0.as_ref().expect(Self::MESSAGE)
}
}
impl DerefMut for HashValueWrapper {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut().expect(Self::MESSAGE)
}
}
type JoinHashMapInner<K> =
ManagedLruCache<K, HashValueWrapper, PrecomputedBuildHasher, SharedStatsAlloc<Global>>;
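/// Locally buffered lookup/miss counters, added to the label-guarded metric counters whenever
/// the hash map is flushed.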
pub struct JoinHashMapMetrics {
lookup_miss_count: usize,
total_lookup_count: usize,
insert_cache_miss_count: usize,
join_lookup_total_count_metric: LabelGuardedIntCounter<4>,
join_lookup_miss_count_metric: LabelGuardedIntCounter<4>,
join_insert_cache_miss_count_metrics: LabelGuardedIntCounter<4>,
}
impl JoinHashMapMetrics {
pub fn new(
metrics: &StreamingMetrics,
actor_id: ActorId,
fragment_id: FragmentId,
side: &'static str,
join_table_id: u32,
) -> Self {
let actor_id = actor_id.to_string();
let fragment_id = fragment_id.to_string();
let join_table_id = join_table_id.to_string();
let join_lookup_total_count_metric = metrics
.join_lookup_total_count
.with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
let join_lookup_miss_count_metric = metrics
.join_lookup_miss_count
.with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
let join_insert_cache_miss_count_metrics = metrics
.join_insert_cache_miss_count
.with_guarded_label_values(&[(side), &join_table_id, &actor_id, &fragment_id]);
Self {
lookup_miss_count: 0,
total_lookup_count: 0,
insert_cache_miss_count: 0,
join_lookup_total_count_metric,
join_lookup_miss_count_metric,
join_insert_cache_miss_count_metrics,
}
}
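/// Add the buffered counts to the guarded counters and reset the local buffers.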
pub fn flush(&mut self) {
self.join_lookup_total_count_metric
.inc_by(self.total_lookup_count as u64);
self.join_lookup_miss_count_metric
.inc_by(self.lookup_miss_count as u64);
self.join_insert_cache_miss_count_metrics
.inc_by(self.insert_cache_miss_count as u64);
self.total_lookup_count = 0;
self.lookup_miss_count = 0;
self.insert_cache_miss_count = 0;
}
}
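/// The column index and serializer used to derive the memcomparable inequality key of a row.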
struct InequalityKeyDesc {
idx: usize,
serializer: OrderedRowSerde,
}
impl InequalityKeyDesc {
pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
let indices = vec![self.idx];
let inequality_key = row.project(&indices);
inequality_key.memcmp_serialize(&self.serializer)
}
}
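/// Cached join state for one side of a hash join: an LRU cache of `JoinEntryState` keyed by
/// the join key, backed by a row state table and, when degrees are maintained, a degree table.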
pub struct JoinHashMap<K: HashKey, S: StateStore> {
inner: JoinHashMapInner<K>,
join_key_data_types: Vec<DataType>,
null_matched: K::Bitmap,
pk_serializer: OrderedRowSerde,
state: TableInner<S>,
degree_state: Option<TableInner<S>>,
need_degree_table: bool,
pk_contained_in_jk: bool,
inequality_key_desc: Option<InequalityKeyDesc>,
metrics: JoinHashMapMetrics,
}
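/// A state table together with the column indices used to project its rows (primary key,
/// join key, and the table's order key).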
pub struct TableInner<S: StateStore> {
pk_indices: Vec<usize>,
join_key_indices: Vec<usize>,
order_key_indices: Vec<usize>,
pub(crate) table: StateTable<S>,
}
impl<S: StateStore> TableInner<S> {
pub fn new(pk_indices: Vec<usize>, join_key_indices: Vec<usize>, table: StateTable<S>) -> Self {
let order_key_indices = table.pk_indices().to_vec();
Self {
pk_indices,
join_key_indices,
order_key_indices,
table,
}
}
fn error_context(&self, row: &impl Row) -> String {
let pk = row.project(&self.pk_indices);
let jk = row.project(&self.join_key_indices);
format!(
"join key: {}, pk: {}, row: {}, state_table_id: {}",
jk.display(),
pk.display(),
row.display(),
self.table.table_id()
)
}
}
impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
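/// Create a `JoinHashMap` for one join side, building the pk and inequality-key serializers
/// and an unbounded LRU cache whose memory usage is tracked via `StatsAlloc`.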
#[allow(clippy::too_many_arguments)]
pub fn new(
watermark_sequence: AtomicU64Ref,
join_key_data_types: Vec<DataType>,
state_join_key_indices: Vec<usize>,
state_all_data_types: Vec<DataType>,
state_table: StateTable<S>,
state_pk_indices: Vec<usize>,
degree_state: Option<TableInner<S>>,
null_matched: K::Bitmap,
pk_contained_in_jk: bool,
inequality_key_idx: Option<usize>,
metrics: Arc<StreamingMetrics>,
actor_id: ActorId,
fragment_id: FragmentId,
side: &'static str,
) -> Self {
let alloc = StatsAlloc::new(Global).shared();
let pk_data_types = state_pk_indices
.iter()
.map(|i| state_all_data_types[*i].clone())
.collect();
let pk_serializer = OrderedRowSerde::new(
pk_data_types,
vec![OrderType::ascending(); state_pk_indices.len()],
);
let inequality_key_desc = inequality_key_idx.map(|idx| {
let serializer = OrderedRowSerde::new(
vec![state_all_data_types[idx].clone()],
vec![OrderType::ascending()],
);
InequalityKeyDesc { idx, serializer }
});
let join_table_id = state_table.table_id();
let state = TableInner {
pk_indices: state_pk_indices,
join_key_indices: state_join_key_indices,
order_key_indices: state_table.pk_indices().to_vec(),
table: state_table,
};
let need_degree_table = degree_state.is_some();
let metrics_info = MetricsInfo::new(
metrics.clone(),
join_table_id,
actor_id,
format!("hash join {}", side),
);
let cache = ManagedLruCache::unbounded_with_hasher_in(
watermark_sequence,
metrics_info,
PrecomputedBuildHasher,
alloc,
);
Self {
inner: cache,
join_key_data_types,
null_matched,
pk_serializer,
state,
degree_state,
need_degree_table,
pk_contained_in_jk,
inequality_key_desc,
metrics: JoinHashMapMetrics::new(&metrics, actor_id, fragment_id, side, join_table_id),
}
}
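/// Initialize the epoch of the state table and, if present, the degree table.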
pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
self.state.table.init_epoch(epoch).await?;
if let Some(degree_state) = &mut self.degree_state {
degree_state.table.init_epoch(epoch).await?;
}
Ok(())
}
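/// Update the vnode bitmap of the state tables after a scaling event. Clears the cache and
/// returns `true` if the cache may contain stale entries.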
pub fn update_vnode_bitmap(&mut self, vnode_bitmap: Arc<Bitmap>) -> bool {
let (_previous_vnode_bitmap, cache_may_stale) =
self.state.table.update_vnode_bitmap(vnode_bitmap.clone());
let _ = self
.degree_state
.as_mut()
.map(|degree_state| degree_state.table.update_vnode_bitmap(vnode_bitmap.clone()));
if cache_may_stale {
self.inner.clear();
}
cache_may_stale
}
pub fn update_watermark(&mut self, watermark: ScalarImpl) {
self.state.table.update_watermark(watermark.clone());
if let Some(degree_state) = &mut self.degree_state {
degree_state.table.update_watermark(watermark);
}
}
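/// Take the state of `key` out of the cache, or load it from the state table on a miss.
/// On a hit the wrapper left in the cache is emptied, so the state must be put back later
/// via `update_state`.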
pub async fn take_state(&mut self, key: &K) -> StreamExecutorResult<HashValueType> {
self.metrics.total_lookup_count += 1;
let state = if self.inner.contains(key) {
let mut state = self.inner.peek_mut(key).unwrap();
state.take()
} else {
self.metrics.lookup_miss_count += 1;
self.fetch_cached_state(key).await?.into()
};
Ok(state)
}
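/// Fetch all rows of `key` from the state table, pairing each row with its degree when a
/// degree table is maintained, and assemble them into a `JoinEntryState`.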
async fn fetch_cached_state(&self, key: &K) -> StreamExecutorResult<JoinEntryState> {
let key = key.deserialize(&self.join_key_data_types)?;
let mut entry_state = JoinEntryState::default();
if self.need_degree_table {
let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
&(Bound::Unbounded, Bound::Unbounded);
let table_iter_fut =
self.state
.table
.iter_with_prefix(&key, sub_range, PrefetchOptions::default());
let degree_state = self.degree_state.as_ref().unwrap();
let degree_table_iter_fut =
degree_state
.table
.iter_with_prefix(&key, sub_range, PrefetchOptions::default());
let (table_iter, degree_table_iter) =
try_join(table_iter_fut, degree_table_iter_fut).await?;
let mut pinned_table_iter = std::pin::pin!(table_iter);
let mut pinned_degree_table_iter = std::pin::pin!(degree_table_iter);
let mut rows = vec![];
let mut degree_rows = vec![];
let mut inconsistency_happened = false;
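// Drain the row and degree iterators in lockstep; one side ending early means the two
// tables disagree on the number of rows for this join key.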
loop {
let (row, degree_row) =
join(pinned_table_iter.next(), pinned_degree_table_iter.next()).await;
let (row, degree_row) = match (row, degree_row) {
(None, None) => break,
(None, Some(_)) => {
inconsistency_happened = true;
consistency_panic!(
"mismatched row and degree table of join key: {:?}, degree table has more rows",
&key
);
break;
}
(Some(_), None) => {
inconsistency_happened = true;
consistency_panic!(
"mismatched row and degree table of join key: {:?}, input table has more rows",
&key
);
break;
}
(Some(r), Some(d)) => (r, d),
};
let row = row?;
let degree_row = degree_row?;
rows.push(row);
degree_rows.push(degree_row);
}
if inconsistency_happened {
// The iterators above are drained in lockstep, so `rows` and `degree_rows` have the
// same length even when a mismatch was detected; re-pair them by pk below instead of
// relying on positional alignment.
let row_iter = stream::iter(rows.into_iter()).peekable();
let degree_row_iter = stream::iter(degree_rows.into_iter()).peekable();
pin_mut!(row_iter);
pin_mut!(degree_row_iter);
loop {
match join(row_iter.as_mut().peek(), degree_row_iter.as_mut().peek()).await {
(None, _) | (_, None) => break,
(Some(row), Some(degree_row)) => match row.key().cmp(degree_row.key()) {
Ordering::Greater => {
degree_row_iter.next().await;
}
Ordering::Less => {
row_iter.next().await;
}
Ordering::Equal => {
let row = row_iter.next().await.unwrap();
let degree_row = degree_row_iter.next().await.unwrap();
let pk = row
.as_ref()
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
let degree_i64 = degree_row
.datum_at(degree_row.len() - 1)
.expect("degree should not be NULL");
let inequality_key = self
.inequality_key_desc
.as_ref()
.map(|desc| desc.serialize_inequal_key_from_row(row.row()));
entry_state
.insert(
pk,
JoinRow::new(row.row(), degree_i64.into_int64() as u64)
.encode(),
inequality_key,
)
.with_context(|| self.state.error_context(row.row()))?;
}
},
}
}
} else {
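// No mismatch detected: rows and degree rows are paired positionally.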
assert_eq!(rows.len(), degree_rows.len());
#[for_await]
for (row, degree_row) in
stream::iter(rows.into_iter().zip_eq_fast(degree_rows.into_iter()))
{
let pk1 = row.key();
let pk2 = degree_row.key();
debug_assert_eq!(
pk1, pk2,
"mismatched pk in degree table: pk1: {pk1:?}, pk2: {pk2:?}",
);
let pk = row
.as_ref()
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
let inequality_key = self
.inequality_key_desc
.as_ref()
.map(|desc| desc.serialize_inequal_key_from_row(row.row()));
let degree_i64 = degree_row
.datum_at(degree_row.len() - 1)
.expect("degree should not be NULL");
entry_state
.insert(
pk,
JoinRow::new(row.row(), degree_i64.into_int64() as u64).encode(),
inequality_key,
)
.with_context(|| self.state.error_context(row.row()))?;
}
}
} else {
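// No degree table: load the rows directly and give each a degree of 0.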
let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
&(Bound::Unbounded, Bound::Unbounded);
let table_iter = self
.state
.table
.iter_with_prefix(&key, sub_range, PrefetchOptions::default())
.await?;
#[for_await]
for entry in table_iter {
let row = entry?;
let pk = row
.as_ref()
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
let inequality_key = self
.inequality_key_desc
.as_ref()
.map(|desc| desc.serialize_inequal_key_from_row(row.row()));
entry_state
.insert(pk, JoinRow::new(row.row(), 0).encode(), inequality_key)
.with_context(|| self.state.error_context(row.row()))?;
}
};
Ok(entry_state)
}
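/// Flush buffered metrics and commit the state (and degree) tables for `epoch`.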
pub async fn flush(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
self.metrics.flush();
self.state.table.commit(epoch).await?;
if let Some(degree_state) = &mut self.degree_state {
degree_state.table.commit(epoch).await?;
}
Ok(())
}
pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
self.state.table.try_flush().await?;
if let Some(degree_state) = &mut self.degree_state {
degree_state.table.try_flush().await?;
}
Ok(())
}
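/// Insert a join row into the cache and write it to the state (and degree) tables. When the
/// entry is not cached, it is created in place only if the primary key is contained in the
/// join key; otherwise only the tables are updated.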
pub fn insert(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
let pk = self.serialize_pk_from_row(&value.row);
let inequality_key = self
.inequality_key_desc
.as_ref()
.map(|desc| desc.serialize_inequal_key_from_row(&value.row));
if self.inner.contains(key) {
let mut entry = self.inner.get_mut(key).unwrap();
entry
.insert(pk, value.encode(), inequality_key)
.with_context(|| self.state.error_context(&value.row))?;
} else if self.pk_contained_in_jk {
self.metrics.insert_cache_miss_count += 1;
let mut state = JoinEntryState::default();
state
.insert(pk, value.encode(), inequality_key)
.with_context(|| self.state.error_context(&value.row))?;
self.update_state(key, state.into());
}
if let Some(degree_state) = self.degree_state.as_mut() {
let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
self.state.table.insert(row);
degree_state.table.insert(degree);
} else {
self.state.table.insert(value.row);
}
Ok(())
}
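/// Insert a plain row with an initial degree of 0.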
pub fn insert_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
let join_row = JoinRow::new(&value, 0);
self.insert(key, join_row)?;
Ok(())
}
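/// Delete a join row from the cached entry (if cached) and from both the state table and the
/// degree table. Must only be used when a degree table is maintained.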
pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
if let Some(mut entry) = self.inner.get_mut(key) {
let pk = (&value.row)
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
let inequality_key = self
.inequality_key_desc
.as_ref()
.map(|desc| desc.serialize_inequal_key_from_row(&value.row));
entry
.remove(pk, inequality_key.as_ref())
.with_context(|| self.state.error_context(&value.row))?;
}
let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
self.state.table.delete(row);
let degree_state = self.degree_state.as_mut().unwrap();
degree_state.table.delete(degree);
Ok(())
}
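/// Delete a plain row from the cached entry (if cached) and the state table, without touching
/// any degree table.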
pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
if let Some(mut entry) = self.inner.get_mut(key) {
let pk = (&value)
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
let inequality_key = self
.inequality_key_desc
.as_ref()
.map(|desc| desc.serialize_inequal_key_from_row(&value));
entry
.remove(pk, inequality_key.as_ref())
.with_context(|| self.state.error_context(&value))?;
}
self.state.table.delete(value);
Ok(())
}
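/// Put an entry (back) into the cache, e.g. after `take_state`.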
pub fn update_state(&mut self, key: &K, state: HashValueType) {
self.inner.put(key.clone(), HashValueWrapper(Some(state)));
}
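/// Apply `action` to the degree of both the cached encoded row and the owned join row, and
/// record the degree change in the degree table.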
fn manipulate_degree(
&mut self,
join_row_ref: &mut StateValueType,
join_row: &mut JoinRow<OwnedRow>,
action: impl Fn(&mut DegreeType),
) {
let old_degree = join_row
.to_table_rows(&self.state.order_key_indices)
.1
.into_owned_row();
action(&mut join_row_ref.degree);
action(&mut join_row.degree);
let new_degree = join_row.to_table_rows(&self.state.order_key_indices).1;
let degree_state = self.degree_state.as_mut().unwrap();
degree_state.table.update(old_degree, new_degree);
}
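/// Increment the degree of a matched row.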
pub fn inc_degree(
&mut self,
join_row_ref: &mut StateValueType,
join_row: &mut JoinRow<OwnedRow>,
) {
self.manipulate_degree(join_row_ref, join_row, |d| *d += 1)
}
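/// Decrement the degree of a matched row, reporting a consistency error instead of
/// underflowing below zero.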
pub fn dec_degree(
&mut self,
join_row_ref: &mut StateValueType,
join_row: &mut JoinRow<OwnedRow>,
) {
self.manipulate_degree(join_row_ref, join_row, |d| {
*d = d.checked_sub(1).unwrap_or_else(|| {
consistency_panic!("Tried to decrement zero join row degree");
0
});
})
}
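/// Evict entries from the LRU cache.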
pub fn evict(&mut self) {
self.inner.evict();
}
pub fn entry_count(&self) -> usize {
self.inner.len()
}
pub fn null_matched(&self) -> &K::Bitmap {
&self.null_matched
}
pub fn table_id(&self) -> u32 {
self.state.table.table_id()
}
pub fn join_key_data_types(&self) -> &[DataType] {
&self.join_key_data_types
}
pub fn check_inequal_key_null(&self, row: &impl Row) -> bool {
let desc = self.inequality_key_desc.as_ref().unwrap();
row.datum_at(desc.idx).is_none()
}
pub fn serialize_inequal_key_from_row(&self, row: impl Row) -> InequalKeyType {
self.inequality_key_desc
.as_ref()
.unwrap()
.serialize_inequal_key_from_row(&row)
}
pub fn serialize_pk_from_row(&self, row: impl Row) -> PkType {
row.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer)
}
}
use risingwave_common_estimate_size::KvSize;
use thiserror::Error;
use super::*;
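/// The in-memory rows of a single join key, keyed by memcomparable-encoded primary key, with
/// an optional index from inequality key to the set of pks sharing it. `kv_heap_size` tracks
/// the estimated heap usage of the entry.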
#[derive(Default)]
pub struct JoinEntryState {
cached: JoinRowSet<PkType, StateValueType>,
inequality_index: JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>>,
kv_heap_size: KvSize,
}
impl EstimateSize for JoinEntryState {
fn estimated_heap_size(&self) -> usize {
self.kv_heap_size.size()
}
}
#[derive(Error, Debug)]
pub enum JoinEntryError {
#[error("double inserting a join state entry")]
Occupied,
#[error("removing a join state entry but it is not in the cache")]
Remove,
#[error("retrieving a pk from the inequality index but it is not in the cache")]
InequalIndex,
}
impl JoinEntryState {
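/// Insert a row keyed by its encoded pk, optionally indexing it by its inequality key.
/// In strict consistency mode a duplicate pk is an error; otherwise the existing row is
/// replaced and a consistency error is reported.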
pub fn insert(
&mut self,
key: PkType,
value: StateValueType,
inequality_key: Option<InequalKeyType>,
) -> Result<&mut StateValueType, JoinEntryError> {
let mut removed = false;
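// In non-strict consistency mode, tolerate a duplicate pk by removing the existing row first.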
if !enable_strict_consistency() {
if let Some(old_value) = self.cached.remove(&key) {
if let Some(inequality_key) = inequality_key.as_ref() {
self.remove_pk_from_inequality_index(&key, inequality_key);
}
self.kv_heap_size.sub(&key, &old_value);
removed = true;
}
}
self.kv_heap_size.add(&key, &value);
if let Some(inequality_key) = inequality_key {
self.insert_pk_to_inequality_index(key.clone(), inequality_key);
}
let ret = self.cached.try_insert(key.clone(), value);
if !enable_strict_consistency() {
assert!(ret.is_ok(), "we have removed existing entry, if any");
if removed {
consistency_error!(?key, "double inserting a join state entry");
}
}
ret.map_err(|_| JoinEntryError::Occupied)
}
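/// Remove a row by pk, dropping it from the inequality index as well. A missing row is an
/// error in strict consistency mode and is only reported otherwise.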
pub fn remove(
&mut self,
pk: PkType,
inequality_key: Option<&InequalKeyType>,
) -> Result<(), JoinEntryError> {
if let Some(value) = self.cached.remove(&pk) {
self.kv_heap_size.sub(&pk, &value);
if let Some(inequality_key) = inequality_key {
self.remove_pk_from_inequality_index(&pk, inequality_key);
}
Ok(())
} else if enable_strict_consistency() {
Err(JoinEntryError::Remove)
} else {
consistency_error!(?pk, "removing a join state entry but it's not in the cache");
Ok(())
}
}
fn remove_pk_from_inequality_index(&mut self, pk: &PkType, inequality_key: &InequalKeyType) {
if let Some(pk_set) = self.inequality_index.get_mut(inequality_key) {
if pk_set.remove(pk).is_none() {
if enable_strict_consistency() {
panic!("removing a pk that it not in the inequality index");
} else {
consistency_error!(?pk, "removing a pk that it not in the inequality index");
};
} else {
self.kv_heap_size.sub(pk, &());
}
if pk_set.is_empty() {
self.inequality_index.remove(inequality_key);
}
}
}
fn insert_pk_to_inequality_index(&mut self, pk: PkType, inequality_key: InequalKeyType) {
if let Some(pk_set) = self.inequality_index.get_mut(&inequality_key) {
let pk_size = pk.estimated_size();
if pk_set.try_insert(pk, ()).is_err() {
if enable_strict_consistency() {
panic!("inserting a pk that it already in the inequality index");
} else {
consistency_error!("inserting a pk that it already in the inequality index");
};
} else {
self.kv_heap_size.add_size(pk_size);
}
} else {
let mut pk_set = JoinRowSet::default();
pk_set.try_insert(pk, ()).unwrap();
self.inequality_index
.try_insert(inequality_key, pk_set)
.unwrap();
}
}
pub fn get(
&self,
pk: &PkType,
data_types: &[DataType],
) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
self.cached
.get(pk)
.map(|encoded| encoded.decode(data_types))
}
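/// Iterate over all cached rows, yielding the mutable encoded state together with its
/// decoded `JoinRow`.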
pub fn values_mut<'a>(
&'a mut self,
data_types: &'a [DataType],
) -> impl Iterator<
Item = (
&'a mut StateValueType,
StreamExecutorResult<JoinRow<OwnedRow>>,
),
> + 'a {
self.cached.values_mut().map(|encoded| {
let decoded = encoded.decode(data_types);
(encoded, decoded)
})
}
pub fn len(&self) -> usize {
self.cached.len()
}
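/// Iterate decoded rows whose inequality key falls within `range`, in index order.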
pub fn range_by_inequality<'a, R>(
&'a self,
range: R,
data_types: &'a [DataType],
) -> impl Iterator<Item = StreamExecutorResult<JoinRow<OwnedRow>>> + 'a
where
R: RangeBounds<InequalKeyType> + 'a,
{
self.inequality_index.range(range).flat_map(|(_, pk_set)| {
pk_set
.keys()
.flat_map(|pk| self.get_by_indexed_pk(pk, data_types))
})
}
pub fn upper_bound_by_inequality<'a>(
&'a self,
bound: Bound<&InequalKeyType>,
data_types: &'a [DataType],
) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
if let Some((_, pk_set)) = self.inequality_index.upper_bound(bound) {
if let Some(pk) = pk_set.first_key_sorted() {
self.get_by_indexed_pk(pk, data_types)
} else {
panic!("pk set for a index record must has at least one element");
}
} else {
None
}
}
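/// Decode a row that is referenced from the inequality index; a cache miss here means the
/// index and the row cache are out of sync.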
pub fn get_by_indexed_pk(
&self,
pk: &PkType,
data_types: &[DataType],
) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>>
{
if let Some(value) = self.cached.get(pk) {
Some(value.decode(data_types))
} else if enable_strict_consistency() {
Some(Err(anyhow!(JoinEntryError::InequalIndex).into()))
} else {
consistency_error!(?pk, "{}", JoinEntryError::InequalIndex.as_report());
None
}
}
pub fn lower_bound_by_inequality<'a>(
&'a self,
bound: Bound<&InequalKeyType>,
data_types: &'a [DataType],
) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
if let Some((_, pk_set)) = self.inequality_index.lower_bound(bound) {
if let Some(pk) = pk_set.first_key_sorted() {
self.get_by_indexed_pk(pk, data_types)
} else {
panic!("pk set for a index record must has at least one element");
}
} else {
None
}
}
pub fn get_first_by_inequality<'a>(
&'a self,
inequality_key: &InequalKeyType,
data_types: &'a [DataType],
) -> Option<StreamExecutorResult<JoinRow<OwnedRow>>> {
if let Some(pk_set) = self.inequality_index.get(inequality_key) {
if let Some(pk) = pk_set.first_key_sorted() {
self.get_by_indexed_pk(pk, data_types)
} else {
panic!("pk set for a index record must has at least one element");
}
} else {
None
}
}
pub fn inequality_index(&self) -> &JoinRowSet<InequalKeyType, JoinRowSet<PkType, ()>> {
&self.inequality_index
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use risingwave_common::array::*;
use risingwave_common::util::iter_util::ZipEqDebug;
use super::*;
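/// Encode each row of `data_chunk` (pk and optional inequality key) and insert it into the
/// `JoinEntryState`.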
fn insert_chunk(
managed_state: &mut JoinEntryState,
pk_indices: &[usize],
col_types: &[DataType],
inequality_key_idx: Option<usize>,
data_chunk: &DataChunk,
) {
let pk_col_type = pk_indices
.iter()
.map(|idx| col_types[*idx].clone())
.collect_vec();
let pk_serializer =
OrderedRowSerde::new(pk_col_type, vec![OrderType::ascending(); pk_indices.len()]);
let inequality_key_type = inequality_key_idx.map(|idx| col_types[idx].clone());
let inequality_key_serializer = inequality_key_type
.map(|data_type| OrderedRowSerde::new(vec![data_type], vec![OrderType::ascending()]));
for row_ref in data_chunk.rows() {
let row: OwnedRow = row_ref.into_owned_row();
let value_indices = (0..row.len() - 1).collect_vec();
let pk = pk_indices.iter().map(|idx| row[*idx].clone()).collect_vec();
let pk = OwnedRow::new(pk)
.project(&value_indices)
.memcmp_serialize(&pk_serializer);
let inequality_key = inequality_key_idx.map(|idx| {
(&row)
.project(&[idx])
.memcmp_serialize(inequality_key_serializer.as_ref().unwrap())
});
let join_row = JoinRow { row, degree: 0 };
managed_state
.insert(pk, join_row.encode(), inequality_key)
.unwrap();
}
}
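/// Assert that the cached rows decode to the expected `col1`/`col2` values with degree 0.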
fn check(
managed_state: &mut JoinEntryState,
col_types: &[DataType],
col1: &[i64],
col2: &[i64],
) {
for ((_, matched_row), (d1, d2)) in managed_state
.values_mut(col_types)
.zip_eq_debug(col1.iter().zip_eq_debug(col2.iter()))
{
let matched_row = matched_row.unwrap();
assert_eq!(matched_row.row[0], Some(ScalarImpl::Int64(*d1)));
assert_eq!(matched_row.row[1], Some(ScalarImpl::Int64(*d2)));
assert_eq!(matched_row.degree, 0);
}
}
#[tokio::test]
async fn test_managed_join_state() {
let mut managed_state = JoinEntryState::default();
let col_types = vec![DataType::Int64, DataType::Int64];
let pk_indices = [0];
let col1 = [3, 2, 1];
let col2 = [4, 5, 6];
let data_chunk1 = DataChunk::from_pretty(
"I I
3 4
2 5
1 6",
);
insert_chunk(
&mut managed_state,
&pk_indices,
&col_types,
None,
&data_chunk1,
);
check(&mut managed_state, &col_types, &col1, &col2);
let col1 = [1, 2, 3, 4, 5];
let col2 = [6, 5, 4, 9, 8];
let data_chunk2 = DataChunk::from_pretty(
"I I
5 8
4 9",
);
insert_chunk(
&mut managed_state,
&pk_indices,
&col_types,
None,
&data_chunk2,
);
check(&mut managed_state, &col_types, &col1, &col2);
}
#[tokio::test]
async fn test_managed_join_state_w_inequality_index() {
let mut managed_state = JoinEntryState::default();
let col_types = vec![DataType::Int64, DataType::Int64];
let pk_indices = [0];
let inequality_key_idx = Some(1);
let inequality_key_serializer =
OrderedRowSerde::new(vec![DataType::Int64], vec![OrderType::ascending()]);
let col1 = [3, 2, 1];
let col2 = [4, 5, 5];
let data_chunk1 = DataChunk::from_pretty(
"I I
3 4
2 5
1 5",
);
insert_chunk(
&mut managed_state,
&pk_indices,
&col_types,
inequality_key_idx,
&data_chunk1,
);
check(&mut managed_state, &col_types, &col1, &col2);
let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(5))])
.memcmp_serialize(&inequality_key_serializer);
let row = managed_state
.upper_bound_by_inequality(Bound::Included(&bound), &col_types)
.unwrap()
.unwrap();
assert_eq!(row.row[0], Some(ScalarImpl::Int64(1)));
let row = managed_state
.upper_bound_by_inequality(Bound::Excluded(&bound), &col_types)
.unwrap()
.unwrap();
assert_eq!(row.row[0], Some(ScalarImpl::Int64(3)));
let col1 = [1, 2, 3, 4, 5];
let col2 = [5, 5, 4, 4, 8];
let data_chunk2 = DataChunk::from_pretty(
"I I
5 8
4 4",
);
insert_chunk(
&mut managed_state,
&pk_indices,
&col_types,
inequality_key_idx,
&data_chunk2,
);
check(&mut managed_state, &col_types, &col1, &col2);
let bound = OwnedRow::new(vec![Some(ScalarImpl::Int64(8))])
.memcmp_serialize(&inequality_key_serializer);
let row = managed_state.lower_bound_by_inequality(Bound::Excluded(&bound), &col_types);
assert!(row.is_none());
}
}