risingwave_storage/table/
merge_sort.rs1use std::collections::BinaryHeap;
16use std::collections::binary_heap::PeekMut;
17use std::error::Error;
18
19use futures::{Stream, StreamExt};
20use futures_async_stream::try_stream;
21
22use super::{KeyedChangeLogRow, KeyedRow};
23
/// Access to the raw key bytes of a row.
///
/// Within this module, [`Node`] compares its peeked rows by these bytes.
pub trait NodePeek {
    /// Returns the key of this row as a byte slice.
    fn vnode_key(&self) -> &[u8];
}
27
28impl<K: AsRef<[u8]>> NodePeek for KeyedRow<K> {
29 fn vnode_key(&self) -> &[u8] {
30 self.key()
31 }
32}
33
34impl<K: AsRef<[u8]>> NodePeek for KeyedChangeLogRow<K> {
35 fn vnode_key(&self) -> &[u8] {
36 self.key()
37 }
38}
39
40struct Node<S, R: NodePeek> {
41 stream: S,
42
43 peeked: R,
46}
47
48impl<S, R: NodePeek> PartialEq for Node<S, R> {
49 fn eq(&self, other: &Self) -> bool {
50 match self.peeked.vnode_key() == other.peeked.vnode_key() {
51 true => unreachable!("primary key from different iters should be unique"),
52 false => false,
53 }
54 }
55}
56impl<S, R: NodePeek> Eq for Node<S, R> {}
57
58impl<S, R: NodePeek> PartialOrd for Node<S, R> {
59 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
60 Some(self.cmp(other))
61 }
62}
63
64impl<S, R: NodePeek> Ord for Node<S, R> {
65 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66 self.peeked
68 .vnode_key()
69 .cmp(other.peeked.vnode_key())
70 .reverse()
71 }
72}
73
74#[try_stream(ok=KO, error=E)]
75pub async fn merge_sort<E, KO, R>(streams: impl IntoIterator<Item = R>)
76where
77 KO: NodePeek + Send + Sync,
78 E: Error,
79 R: Stream<Item = Result<KO, E>> + Unpin,
80{
81 let mut heap = BinaryHeap::new();
82 for mut stream in streams {
83 if let Some(peeked) = stream.next().await.transpose()? {
84 heap.push(Node { stream, peeked });
85 }
86 }
87 while let Some(mut node) = heap.peek_mut() {
88 yield match node.stream.next().await.transpose()? {
90 Some(new_peeked) => std::mem::replace(&mut node.peeked, new_peeked),
92 None => PeekMut::pop(node).peeked,
94 };
95 }
96}
97
#[cfg(test)]
mod tests {
    use futures_async_stream::for_await;
    use rand::random_range;
    use risingwave_common::hash::VirtualNode;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::ScalarImpl;
    use risingwave_hummock_sdk::key::TableKey;

    use super::*;
    use crate::error::StorageResult;

    /// Total number of rows fed into `merge_sort` by `test_merge_sort` (keys `0..=10`).
    const TOTAL_ROWS: usize = 11;

    /// Builds a row whose table key is a random vnode's big-endian bytes followed by the
    /// single byte `i`, and whose only column is `i` as an `Int64`.
    fn gen_pk_and_row(i: u8) -> StorageResult<KeyedRow<Vec<u8>>> {
        let vnode = VirtualNode::from_index(random_range(..VirtualNode::COUNT_FOR_TEST));
        let mut key = vnode.to_be_bytes().to_vec();
        key.push(i);
        Ok(KeyedRow::new(
            TableKey(key),
            OwnedRow::new(vec![Some(ScalarImpl::Int64(i64::from(i)))]),
        ))
    }

    /// Merging three interleaved streams and one empty stream yields rows `0..=10` in order.
    #[tokio::test]
    async fn test_merge_sort() {
        let streams = vec![
            futures::stream::iter(vec![
                gen_pk_and_row(0),
                gen_pk_and_row(3),
                gen_pk_and_row(6),
                gen_pk_and_row(9),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(1),
                gen_pk_and_row(4),
                gen_pk_and_row(7),
                gen_pk_and_row(10),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(2),
                gen_pk_and_row(5),
                gen_pk_and_row(8),
            ]),
            // An empty input must be tolerated.
            futures::stream::iter(vec![]),
        ];

        let merge_sorted = merge_sort(streams);

        let mut yielded = 0usize;
        #[for_await]
        for (i, result) in merge_sorted.enumerate() {
            let index = u8::try_from(i).expect("merged stream yielded more rows than fit in u8");
            let expected = gen_pk_and_row(index).unwrap();
            let actual = result.unwrap();
            // NOTE(review): the vnode is random per call, so this comparison presumably
            // relies on `key()` excluding the vnode prefix — confirm against `KeyedRow::key`.
            assert_eq!(actual.key(), expected.key());
            assert_eq!(actual.into_owned_row(), expected.into_owned_row());
            yielded += 1;
        }
        // Without this check the test passes vacuously if the merged stream ends early
        // or yields nothing at all.
        assert_eq!(yielded, TOTAL_ROWS);
    }
}