risingwave_stream/executor/join/
join_row_set.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Borrow;
16use std::collections::BTreeMap;
17use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError;
18use std::fmt::Debug;
19use std::mem;
20use std::ops::{Bound, RangeBounds};
21
22use auto_enums::auto_enum;
23use enum_as_inner::EnumAsInner;
24
/// Size threshold that drives the choice of backing store in `JoinRowSet`.
///
/// `try_insert` converts a `Vec`-backed set into a `BTreeMap` once it already
/// holds at least this many entries, and `remove` converts a `BTreeMap`-backed
/// set back into a `Vec` once it shrinks to `MAX_VEC_SIZE / 2` entries or fewer.
const MAX_VEC_SIZE: usize = 4;
26
/// A key-value container with two interchangeable backing stores.
///
/// The set starts out as a `Vec` (see the `Default` impl) and is switched
/// between the two variants by `try_insert` and `remove` depending on the
/// number of entries relative to `MAX_VEC_SIZE`.
#[derive(Debug, EnumAsInner)]
pub enum JoinRowSet<K, V> {
    /// Entries stored in a `BTreeMap`, i.e. ordered by key.
    BTree(BTreeMap<K, V>),
    /// Entries stored as `(key, value)` pairs in a `Vec`. Insertion appends and
    /// removal uses `swap_remove`, so the pairs are in no particular key order.
    Vec(Vec<(K, V)>),
}
32
33impl<K, V> Default for JoinRowSet<K, V> {
34    fn default() -> Self {
35        Self::Vec(Vec::new())
36    }
37}
38
/// Error produced by `JoinRowSet::try_insert` when the set is `Vec`-backed and
/// the key is already present.
#[derive(Debug)]
// NOTE(review): the fields are never read directly in this file; presumably they
// are only surfaced through the derived `Debug` output — confirm before removing.
#[allow(dead_code)]
pub struct VecOccupiedError<'a, K, V> {
    /// The key already stored in the set.
    key: &'a K,
    /// The value currently stored under that key (left untouched).
    old_value: &'a V,
    /// The value that was rejected by the failed insert.
    new_value: V,
}
46
/// Error returned by `JoinRowSet::try_insert` when the key already exists.
///
/// The variant mirrors the backing store the set had when the insert was attempted.
#[derive(Debug)]
pub enum JoinRowSetOccupiedError<'a, K: Ord, V> {
    /// The set was `BTreeMap`-backed; wraps the standard library's occupied error.
    BTree(BTreeMapOccupiedError<'a, K, V>),
    /// The set was `Vec`-backed; wraps the local `VecOccupiedError`.
    Vec(VecOccupiedError<'a, K, V>),
}
52
53impl<K: Ord, V> JoinRowSet<K, V> {
54    pub fn try_insert(
55        &mut self,
56        key: K,
57        value: V,
58    ) -> Result<&'_ mut V, JoinRowSetOccupiedError<'_, K, V>> {
59        if let Self::Vec(inner) = self
60            && inner.len() >= MAX_VEC_SIZE
61        {
62            let btree = BTreeMap::from_iter(inner.drain(..));
63            mem::swap(self, &mut Self::BTree(btree));
64        }
65
66        match self {
67            Self::BTree(inner) => inner
68                .try_insert(key, value)
69                .map_err(JoinRowSetOccupiedError::BTree),
70            Self::Vec(inner) => {
71                if let Some(pos) = inner.iter().position(|elem| elem.0 == key) {
72                    Err(JoinRowSetOccupiedError::Vec(VecOccupiedError {
73                        key: &inner[pos].0,
74                        old_value: &inner[pos].1,
75                        new_value: value,
76                    }))
77                } else {
78                    if inner.capacity() == 0 {
79                        // `Vec` will give capacity 4 when `1 < mem::size_of::<T> <= 1024`
80                        // We only give one for memory optimization
81                        inner.reserve_exact(1);
82                    }
83                    inner.push((key, value));
84                    Ok(&mut inner.last_mut().unwrap().1)
85                }
86            }
87        }
88    }
89
90    pub fn remove(&mut self, key: &K) -> Option<V> {
91        let ret = match self {
92            Self::BTree(inner) => inner.remove(key),
93            Self::Vec(inner) => inner
94                .iter()
95                .position(|elem| &elem.0 == key)
96                .map(|pos| inner.swap_remove(pos).1),
97        };
98        if let Self::BTree(inner) = self
99            && inner.len() <= MAX_VEC_SIZE / 2
100        {
101            let btree = mem::take(inner);
102            let vec = Vec::from_iter(btree);
103            mem::swap(self, &mut Self::Vec(vec));
104        }
105        ret
106    }
107
108    pub fn len(&self) -> usize {
109        match self {
110            Self::BTree(inner) => inner.len(),
111            Self::Vec(inner) => inner.len(),
112        }
113    }
114
115    pub fn is_empty(&self) -> bool {
116        match self {
117            Self::BTree(inner) => inner.is_empty(),
118            Self::Vec(inner) => inner.is_empty(),
119        }
120    }
121
122    #[auto_enum(Iterator)]
123    pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> {
124        match self {
125            Self::BTree(inner) => inner.values_mut(),
126            Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v),
127        }
128    }
129
130    #[auto_enum(Iterator)]
131    pub fn keys(&self) -> impl Iterator<Item = &K> {
132        match self {
133            Self::BTree(inner) => inner.keys(),
134            Self::Vec(inner) => inner.iter().map(|(k, _v)| k),
135        }
136    }
137
138    #[auto_enum(Iterator)]
139    pub fn range<T, R>(&self, range: R) -> impl Iterator<Item = (&K, &V)>
140    where
141        T: Ord + ?Sized,
142        K: Borrow<T> + Ord,
143        R: RangeBounds<T>,
144    {
145        match self {
146            Self::BTree(inner) => inner.range(range),
147            Self::Vec(inner) => inner
148                .iter()
149                .filter(move |(k, _)| range.contains(k.borrow()))
150                .map(|(k, v)| (k, v)),
151        }
152    }
153
154    pub fn lower_bound_key(&self, bound: Bound<&K>) -> Option<&K> {
155        self.lower_bound(bound).map(|(k, _v)| k)
156    }
157
158    pub fn upper_bound_key(&self, bound: Bound<&K>) -> Option<&K> {
159        self.upper_bound(bound).map(|(k, _v)| k)
160    }
161
162    pub fn lower_bound(&self, bound: Bound<&K>) -> Option<(&K, &V)> {
163        match self {
164            Self::BTree(inner) => inner.lower_bound(bound).next(),
165            Self::Vec(inner) => inner
166                .iter()
167                .filter(|(k, _)| (bound, Bound::Unbounded).contains(k))
168                .min_by_key(|(k, _)| k)
169                .map(|(k, v)| (k, v)),
170        }
171    }
172
173    pub fn upper_bound(&self, bound: Bound<&K>) -> Option<(&K, &V)> {
174        match self {
175            Self::BTree(inner) => inner.upper_bound(bound).prev(),
176            Self::Vec(inner) => inner
177                .iter()
178                .filter(|(k, _)| (Bound::Unbounded, bound).contains(k))
179                .max_by_key(|(k, _)| k)
180                .map(|(k, v)| (k, v)),
181        }
182    }
183
184    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
185        match self {
186            Self::BTree(inner) => inner.get_mut(key),
187            Self::Vec(inner) => inner.iter_mut().find(|(k, _)| k == key).map(|(_, v)| v),
188        }
189    }
190
191    pub fn get(&self, key: &K) -> Option<&V> {
192        match self {
193            Self::BTree(inner) => inner.get(key),
194            Self::Vec(inner) => inner.iter().find(|(k, _)| k == key).map(|(_, v)| v),
195        }
196    }
197
198    /// Returns the key-value pair with smallest key in the map.
199    pub fn first_key_sorted(&self) -> Option<&K> {
200        match self {
201            Self::BTree(inner) => inner.first_key_value().map(|(k, _)| k),
202            Self::Vec(inner) => inner.iter().map(|(k, _)| k).min(),
203        }
204    }
205
206    /// Returns the key-value pair with the second smallest key in the map.
207    pub fn second_key_sorted(&self) -> Option<&K> {
208        match self {
209            Self::BTree(inner) => inner.iter().nth(1).map(|(k, _)| k),
210            Self::Vec(inner) => {
211                let mut res = None;
212                let mut smallest = None;
213                for (k, _) in inner {
214                    if let Some(smallest_k) = smallest {
215                        if k < smallest_k {
216                            res = Some(smallest_k);
217                            smallest = Some(k);
218                        } else if let Some(res_k) = res {
219                            if k < res_k {
220                                res = Some(k);
221                            }
222                        } else {
223                            res = Some(k);
224                        }
225                    } else {
226                        smallest = Some(k);
227                    }
228                }
229                res
230            }
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set from `(key, value)` pairs, asserting that every insert succeeds.
    fn build(pairs: &[(i32, i32)]) -> JoinRowSet<i32, i32> {
        let mut set: JoinRowSet<i32, i32> = JoinRowSet::default();
        for &(k, v) in pairs {
            assert!(set.try_insert(k, v).is_ok());
        }
        set
    }

    #[test]
    fn test_join_row_set_bounds() {
        let set = build(&[(1, 10), (2, 20), (3, 30)]);

        // Lower bound: smallest key at or after the bound.
        assert_eq!(set.lower_bound_key(Bound::Included(&2)), Some(&2));
        assert_eq!(set.lower_bound_key(Bound::Excluded(&2)), Some(&3));

        // Upper bound: largest key at or before the bound.
        assert_eq!(set.upper_bound_key(Bound::Included(&2)), Some(&2));
        assert_eq!(set.upper_bound_key(Bound::Excluded(&2)), Some(&1));
    }

    #[test]
    fn test_join_row_set_first_and_second_key_sorted() {
        // Keys inserted out of order.
        let unordered = build(&[(3, 30), (1, 10), (2, 20)]);
        assert_eq!(unordered.first_key_sorted(), Some(&1));
        assert_eq!(unordered.second_key_sorted(), Some(&2));

        // Keys inserted in ascending order, only two entries.
        let ordered = build(&[(1, 10), (2, 20)]);
        assert_eq!(ordered.first_key_sorted(), Some(&1));
        assert_eq!(ordered.second_key_sorted(), Some(&2));
    }
}