risingwave_storage/table/
merge_sort.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BinaryHeap;
16use std::collections::binary_heap::PeekMut;
17use std::error::Error;
18
19use futures::{Stream, StreamExt};
20use futures_async_stream::try_stream;
21
22use super::{KeyedChangeLogRow, KeyedRow};
23
/// Gives access to the byte key of an item flowing through [`merge_sort`].
///
/// The returned bytes are what the merge compares to decide which item is
/// yielded next.
pub trait NodePeek {
    /// Returns the key bytes of this item, borrowed from `self`.
    fn vnode_key(&self) -> &[u8];
}
27
28impl<K: AsRef<[u8]>> NodePeek for KeyedRow<K> {
29    fn vnode_key(&self) -> &[u8] {
30        self.key()
31    }
32}
33
34impl<K: AsRef<[u8]>> NodePeek for KeyedChangeLogRow<K> {
35    fn vnode_key(&self) -> &[u8] {
36        self.key()
37    }
38}
39
40struct Node<S, R: NodePeek> {
41    stream: S,
42
43    /// The next item polled from `stream` previously. Since the `eq` and `cmp` must be synchronous
44    /// functions, we need to implement peeking manually.
45    peeked: R,
46}
47
48impl<S, R: NodePeek> PartialEq for Node<S, R> {
49    fn eq(&self, other: &Self) -> bool {
50        match self.peeked.vnode_key() == other.peeked.vnode_key() {
51            true => unreachable!("primary key from different iters should be unique"),
52            false => false,
53        }
54    }
55}
56impl<S, R: NodePeek> Eq for Node<S, R> {}
57
58impl<S, R: NodePeek> PartialOrd for Node<S, R> {
59    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
60        Some(self.cmp(other))
61    }
62}
63
64impl<S, R: NodePeek> Ord for Node<S, R> {
65    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66        // The heap is a max heap, so we need to reverse the order.
67        self.peeked
68            .vnode_key()
69            .cmp(other.peeked.vnode_key())
70            .reverse()
71    }
72}
73
74#[try_stream(ok=KO, error=E)]
75pub async fn merge_sort<E, KO, R>(streams: impl IntoIterator<Item = R>)
76where
77    KO: NodePeek + Send + Sync,
78    E: Error,
79    R: Stream<Item = Result<KO, E>> + Unpin,
80{
81    let mut heap = BinaryHeap::new();
82    for mut stream in streams {
83        if let Some(peeked) = stream.next().await.transpose()? {
84            heap.push(Node { stream, peeked });
85        }
86    }
87    while let Some(mut node) = heap.peek_mut() {
88        // Note: If the `next` returns `Err`, we'll fail to yield the previous item.
89        yield match node.stream.next().await.transpose()? {
90            // There still remains data in the stream, take and update the peeked value.
91            Some(new_peeked) => std::mem::replace(&mut node.peeked, new_peeked),
92            // This stream is exhausted, remove it from the heap.
93            None => PeekMut::pop(node).peeked,
94        };
95    }
96}
97
#[cfg(test)]
mod tests {
    use futures_async_stream::for_await;
    use rand::random_range;
    use risingwave_common::hash::VirtualNode;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::ScalarImpl;
    use risingwave_hummock_sdk::key::TableKey;

    use super::*;
    use crate::error::StorageResult;

    /// Builds a test row for index `i`.
    ///
    /// The table key is the big-endian bytes of a randomly chosen test vnode
    /// followed by the single byte `i`; the row has one `Int64` column holding `i`.
    fn gen_pk_and_row(i: u8) -> StorageResult<KeyedRow<Vec<u8>>> {
        let vnode = VirtualNode::from_index(random_range(..VirtualNode::COUNT_FOR_TEST));
        let mut key = vnode.to_be_bytes().to_vec();
        key.push(i);
        Ok(KeyedRow::new(
            TableKey(key),
            OwnedRow::new(vec![Some(ScalarImpl::Int64(i64::from(i)))]),
        ))
    }

    /// Merges three interleaved streams plus one empty stream and checks that
    /// rows come out as 0, 1, 2, ... and that none of them is missing.
    #[tokio::test]
    async fn test_merge_sort() {
        let streams = vec![
            futures::stream::iter(vec![
                gen_pk_and_row(0),
                gen_pk_and_row(3),
                gen_pk_and_row(6),
                gen_pk_and_row(9),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(1),
                gen_pk_and_row(4),
                gen_pk_and_row(7),
                gen_pk_and_row(10),
            ]),
            futures::stream::iter(vec![
                gen_pk_and_row(2),
                gen_pk_and_row(5),
                gen_pk_and_row(8),
            ]),
            futures::stream::iter(vec![]), // empty stream
        ];

        let merge_sorted = merge_sort(streams);

        // Counted separately because the per-row assertions below cannot
        // detect a stream that ends too early (or yields nothing at all).
        let mut yielded = 0usize;

        #[for_await]
        for (i, result) in merge_sorted.enumerate() {
            let expected = gen_pk_and_row(i as u8).unwrap();
            let actual = result.unwrap();
            assert_eq!(actual.key(), expected.key());
            assert_eq!(actual.into_owned_row(), expected.into_owned_row());
            yielded += 1;
        }

        // The inputs hold 4 + 4 + 3 + 0 rows.
        assert_eq!(yielded, 11);
    }
}
153}