// risingwave_storage/table/merge_sort.rs
use std::collections::binary_heap::PeekMut;
use std::collections::BinaryHeap;
use std::error::Error;
use futures::{Stream, StreamExt};
use futures_async_stream::try_stream;
use super::{KeyedChangeLogRow, KeyedRow};
/// Exposes the key bytes of an item so that it can be ordered while merging.
///
/// `merge_sort` requires its stream items to implement this trait, and the
/// ordering impls of the internal `Node` type compare items solely by the
/// slice returned from `vnode_key`.
pub trait NodePeek {
/// Returns the key bytes by which this item is compared against others.
fn vnode_key(&self) -> &[u8];
}
/// Allows `KeyedRow` items to be merged by `merge_sort`.
impl<K: AsRef<[u8]>> NodePeek for KeyedRow<K> {
// Forwards to the row's own `key()` accessor; no bytes are copied here.
// NOTE(review): whether `key()` includes a vnode prefix is decided by
// `KeyedRow`, which is defined outside this file — confirm there.
fn vnode_key(&self) -> &[u8] {
self.key()
}
}
/// Allows `KeyedChangeLogRow` items to be merged by `merge_sort`.
impl<K: AsRef<[u8]>> NodePeek for KeyedChangeLogRow<K> {
// Forwards to the change-log row's own `key()` accessor.
// NOTE(review): the exact contents of `key()` are defined by
// `KeyedChangeLogRow`, outside this file — confirm there.
fn vnode_key(&self) -> &[u8] {
self.key()
}
}
/// One entry of the merge heap: an input stream together with the item that
/// has been pulled from it but not yet emitted.
///
/// The struct itself places no trait bound on `R`; the comparison impls that
/// need key access state the `NodePeek` bound themselves.
struct Node<S, R> {
    /// The input this entry draws from.
    stream: S,
    /// The item most recently pulled from `stream` and still waiting to be
    /// emitted; its key decides this entry's position in the heap.
    peeked: R,
}
/// Equality over the peeked keys.
///
/// Two nodes never compare equal: distinct keys yield `false`, and identical
/// keys are treated as a broken invariant and panic.
impl<S, R: NodePeek> PartialEq for Node<S, R> {
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.peeked.vnode_key();
        let rhs = other.peeked.vnode_key();
        if lhs == rhs {
            // Every input is expected to contribute keys that no other input has.
            unreachable!("primary key from different iters should be unique");
        }
        false
    }
}
// Marker impl, needed because `Ord` has `Eq` as a supertrait.
// Note that `PartialEq::eq` for `Node` never returns `true`: identical keys
// reach `unreachable!` instead.
impl<S, R: NodePeek> Eq for Node<S, R> {}
/// Partial ordering that always agrees with the total order defined by `Ord`.
impl<S, R: NodePeek> PartialOrd for Node<S, R> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Defer to `Ord` so both orderings can never diverge.
        Some(Ord::cmp(self, other))
    }
}
/// Total order over the peeked keys, inverted so that a max-heap
/// (`BinaryHeap`) surfaces the node with the smallest key first.
impl<S, R: NodePeek> Ord for Node<S, R> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let mine = self.peeked.vnode_key();
        let theirs = other.peeked.vnode_key();
        // Comparing `theirs` against `mine` (operands swapped) yields the
        // reversed byte-wise order.
        theirs.cmp(mine)
    }
}
/// Merges several streams into a single stream, always emitting next the
/// buffered item whose [`NodePeek::vnode_key`] is smallest.
///
/// One item per input is buffered in a heap. Before an item is emitted, its
/// source stream is advanced so the heap entry can take the successor's key;
/// an exhausted input is removed from the heap.
///
/// The output is globally ordered only if every input is itself ordered by
/// key — this function does not verify that (assumption; confirm with callers).
///
/// # Errors
///
/// The first `Err` produced by any input is propagated through the returned
/// stream.
#[try_stream(ok=KO, error=E)]
pub async fn merge_sort<E, KO, R>(streams: impl IntoIterator<Item = R>)
where
    KO: NodePeek + Send + Sync,
    E: Error,
    R: Stream<Item = Result<KO, E>> + Unpin,
{
    let mut heap = BinaryHeap::new();

    // Seed the heap with the first item of every non-empty input.
    for mut source in streams {
        let head = source.next().await.transpose()?;
        match head {
            Some(peeked) => heap.push(Node {
                stream: source,
                peeked,
            }),
            None => continue,
        }
    }

    while let Some(mut top) = heap.peek_mut() {
        // Fetch the successor first: its key decides where this entry lands
        // in the heap once the `PeekMut` guard is released.
        let successor = top.stream.next().await.transpose()?;
        let smallest = match successor {
            Some(upcoming) => std::mem::replace(&mut top.peeked, upcoming),
            // The input is drained: take the entry out of the heap for good.
            None => PeekMut::pop(top).peeked,
        };
        yield smallest;
    }
}
#[cfg(test)]
mod tests {
    use futures_async_stream::for_await;
    use rand::random;
    use risingwave_common::hash::VirtualNode;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::ScalarImpl;
    use risingwave_hummock_sdk::key::TableKey;

    use super::*;
    use crate::error::StorageResult;

    /// Builds a row whose table key is a randomly chosen vnode's big-endian
    /// bytes followed by the single byte `i`, and whose only column is `i`
    /// as an `Int64`.
    ///
    /// NOTE(review): the vnode differs on every call, so the key assertions
    /// below presumably rely on `KeyedRow::key()` excluding the vnode prefix —
    /// confirm against `KeyedRow`.
    fn gen_pk_and_row(i: u8) -> StorageResult<KeyedRow<Vec<u8>>> {
        let vnode = VirtualNode::from_index(random::<usize>() % VirtualNode::COUNT_FOR_TEST);
        let mut key = vnode.to_be_bytes().to_vec();
        key.push(i);
        Ok(KeyedRow::new(
            TableKey(key),
            OwnedRow::new(vec![Some(ScalarImpl::Int64(i as _))]),
        ))
    }

    /// Interleaves three ordered inputs plus one empty input and checks that
    /// rows 0..=10 come out in ascending order, each exactly once.
    #[tokio::test]
    async fn test_merge_sort() {
        // 4 + 4 + 3 + 0 rows across the inputs below.
        const EXPECTED_ROWS: usize = 11;

        let streams = vec![
            futures::stream::iter(vec![
                gen_pk_and_row(0),
                gen_pk_and_row(3),
                gen_pk_and_row(6),
                gen_pk_and_row(9),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(1),
                gen_pk_and_row(4),
                gen_pk_and_row(7),
                gen_pk_and_row(10),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(2),
                gen_pk_and_row(5),
                gen_pk_and_row(8),
            ]),
            futures::stream::iter(vec![]),
        ];

        let merge_sorted = merge_sort(streams);

        // Counted explicitly so that an empty or truncated output cannot pass
        // the per-row assertions vacuously.
        let mut yielded = 0usize;

        #[for_await]
        for (i, result) in merge_sorted.enumerate() {
            let expected = gen_pk_and_row(i as u8).unwrap();
            let actual = result.unwrap();
            assert_eq!(actual.key(), expected.key());
            assert_eq!(actual.into_owned_row(), expected.into_owned_row());
            yielded += 1;
        }

        assert_eq!(yielded, EXPECTED_ROWS);
    }
}