// risingwave_storage/table/merge_sort.rs
use std::collections::BinaryHeap;
16use std::collections::binary_heap::PeekMut;
17use std::error::Error;
18
19use futures::{Stream, StreamExt};
20use futures_async_stream::try_stream;
21
22use super::{KeyedChangeLogRow, KeyedRow};
23
/// Exposes the byte key by which [`merge_sort`] orders items.
///
/// Any item type that is fed through [`merge_sort`] must implement this trait so that
/// the merge can compare the heads of the individual input streams.
pub trait NodePeek {
    /// Returns the bytes used as this item's sort key.
    fn vnode_key(&self) -> &[u8];
}
27
28impl<K: AsRef<[u8]>> NodePeek for KeyedRow<K> {
29    fn vnode_key(&self) -> &[u8] {
30        self.key()
31    }
32}
33
34impl<K: AsRef<[u8]>> NodePeek for KeyedChangeLogRow<K> {
35    fn vnode_key(&self) -> &[u8] {
36        self.key()
37    }
38}
39
40struct Node<S, R: NodePeek> {
41    stream: S,
42
43    peeked: R,
46}
47
48impl<S, R: NodePeek> PartialEq for Node<S, R> {
49    fn eq(&self, other: &Self) -> bool {
50        match self.peeked.vnode_key() == other.peeked.vnode_key() {
51            true => unreachable!("primary key from different iters should be unique"),
52            false => false,
53        }
54    }
55}
56impl<S, R: NodePeek> Eq for Node<S, R> {}
57
58impl<S, R: NodePeek> PartialOrd for Node<S, R> {
59    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
60        Some(self.cmp(other))
61    }
62}
63
64impl<S, R: NodePeek> Ord for Node<S, R> {
65    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66        self.peeked
68            .vnode_key()
69            .cmp(other.peeked.vnode_key())
70            .reverse()
71    }
72}
73
74#[try_stream(ok=KO, error=E)]
75pub async fn merge_sort<E, KO, R>(streams: impl IntoIterator<Item = R>)
76where
77    KO: NodePeek + Send + Sync,
78    E: Error,
79    R: Stream<Item = Result<KO, E>> + Unpin,
80{
81    let mut heap = BinaryHeap::new();
82    for mut stream in streams {
83        if let Some(peeked) = stream.next().await.transpose()? {
84            heap.push(Node { stream, peeked });
85        }
86    }
87    while let Some(mut node) = heap.peek_mut() {
88        yield match node.stream.next().await.transpose()? {
90            Some(new_peeked) => std::mem::replace(&mut node.peeked, new_peeked),
92            None => PeekMut::pop(node).peeked,
94        };
95    }
96}
97
#[cfg(test)]
mod tests {
    use futures_async_stream::for_await;
    use rand::random_range;
    use risingwave_common::hash::VirtualNode;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::ScalarImpl;
    use risingwave_hummock_sdk::key::TableKey;

    use super::*;
    use crate::error::StorageResult;

    /// Builds a row for `i`: the table key is the big-endian bytes of a randomly chosen
    /// vnode followed by the single byte `i`, and the row holds one `Int64` column equal
    /// to `i`.
    fn gen_pk_and_row(i: u8) -> StorageResult<KeyedRow<Vec<u8>>> {
        let vnode = VirtualNode::from_index(random_range(..VirtualNode::COUNT_FOR_TEST));
        let mut key = vnode.to_be_bytes().to_vec();
        key.push(i);
        Ok(KeyedRow::new(
            TableKey(key),
            OwnedRow::new(vec![Some(ScalarImpl::Int64(i as _))]),
        ))
    }

    /// Merges three interleaved sorted streams plus one empty stream and checks that the
    /// rows come out as 0, 1, 2, ... and that none of them is missing.
    #[tokio::test]
    async fn test_merge_sort() {
        // Number of rows supplied by all input streams together (4 + 4 + 3 + 0).
        const TOTAL_ROWS: usize = 11;

        let streams = vec![
            futures::stream::iter(vec![
                gen_pk_and_row(0),
                gen_pk_and_row(3),
                gen_pk_and_row(6),
                gen_pk_and_row(9),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(1),
                gen_pk_and_row(4),
                gen_pk_and_row(7),
                gen_pk_and_row(10),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(2),
                gen_pk_and_row(5),
                gen_pk_and_row(8),
            ]),
            // An empty stream must simply be skipped.
            futures::stream::iter(vec![]),
        ];

        let merge_sorted = merge_sort(streams);

        let mut yielded = 0usize;
        #[for_await]
        for (i, result) in merge_sorted.enumerate() {
            // NOTE(review): the vnode is random per call, so this key comparison presumably
            // relies on `KeyedRow::key()` not including the vnode prefix — confirm.
            let expected = gen_pk_and_row(i as u8).unwrap();
            let actual = result.unwrap();
            assert_eq!(actual.key(), expected.key());
            assert_eq!(actual.into_owned_row(), expected.into_owned_row());
            yielded += 1;
        }

        // Without this check the test would pass vacuously if the merged stream ended
        // early (or yielded nothing at all).
        assert_eq!(yielded, TOTAL_ROWS);
    }
}