risingwave_stream/executor/join/
join_row_set.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Borrow;
16use std::collections::BTreeMap;
17use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError;
18use std::fmt::Debug;
19use std::mem;
20
21use auto_enums::auto_enum;
22use enum_as_inner::EnumAsInner;
23
/// Size threshold that drives switching between the two `JoinRowSet`
/// representations.
///
/// `try_insert` converts a `Vec` that already holds at least this many entries
/// into a `BTreeMap` before inserting, and `remove` converts a `BTreeMap` back
/// into a `Vec` once it holds at most half this many entries.
const MAX_VEC_SIZE: usize = 4;
25
/// A keyed collection of values with two interchangeable storage
/// representations.
///
/// The default value is an empty `Vec`; the methods on this type switch
/// between the representations based on the number of stored entries.
#[derive(Debug, EnumAsInner)]
pub enum JoinRowSet<K, V> {
    /// Entries stored in a `BTreeMap`, ordered by key.
    BTree(BTreeMap<K, V>),
    /// Entries stored as `(key, value)` pairs in no particular order; lookups
    /// scan the pairs linearly.
    Vec(Vec<(K, V)>),
}
31
32impl<K, V> Default for JoinRowSet<K, V> {
33    fn default() -> Self {
34        Self::Vec(Vec::new())
35    }
36}
37
/// Error payload built by `JoinRowSet::try_insert` when the key is already
/// present while the set is in its `Vec` representation.
#[derive(Debug)]
// NOTE(review): the fields are written by `try_insert` but never read in the
// visible code, which is presumably why the `dead_code` lint is expected here.
#[expect(dead_code)]
pub struct VecOccupiedError<'a, K, V> {
    /// The key already stored in the set.
    key: &'a K,
    /// The value currently associated with that key.
    old_value: &'a V,
    /// The value whose insertion was rejected.
    new_value: V,
}
45
/// Error returned by `JoinRowSet::try_insert` when the key is already present.
///
/// The variant matches the representation the set was in when the insertion
/// was attempted.
#[derive(Debug)]
pub enum JoinRowSetOccupiedError<'a, K: Ord, V> {
    /// The key was already present in the `BTreeMap` representation.
    BTree(BTreeMapOccupiedError<'a, K, V>),
    /// The key was already present in the `Vec` representation.
    Vec(VecOccupiedError<'a, K, V>),
}
51
52impl<K: Ord, V> JoinRowSet<K, V> {
53    pub fn try_insert(
54        &mut self,
55        key: K,
56        value: V,
57    ) -> Result<&'_ mut V, JoinRowSetOccupiedError<'_, K, V>> {
58        if let Self::Vec(inner) = self
59            && inner.len() >= MAX_VEC_SIZE
60        {
61            let btree = BTreeMap::from_iter(inner.drain(..));
62            *self = Self::BTree(btree);
63        }
64
65        match self {
66            Self::BTree(inner) => inner
67                .try_insert(key, value)
68                .map_err(JoinRowSetOccupiedError::BTree),
69            Self::Vec(inner) => {
70                if let Some(pos) = inner.iter().position(|elem| elem.0 == key) {
71                    Err(JoinRowSetOccupiedError::Vec(VecOccupiedError {
72                        key: &inner[pos].0,
73                        old_value: &inner[pos].1,
74                        new_value: value,
75                    }))
76                } else {
77                    if inner.capacity() == 0 {
78                        // `Vec` will give capacity 4 when `1 < mem::size_of::<T> <= 1024`
79                        // We only give one for memory optimization
80                        inner.reserve_exact(1);
81                    }
82                    inner.push((key, value));
83                    Ok(&mut inner.last_mut().unwrap().1)
84                }
85            }
86        }
87    }
88
89    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
90    where
91        K: Borrow<Q>,
92        Q: Ord + ?Sized,
93    {
94        let ret = match self {
95            Self::BTree(inner) => inner.remove(key),
96            Self::Vec(inner) => inner
97                .iter()
98                .position(|elem| elem.0.borrow() == key)
99                .map(|pos| inner.swap_remove(pos).1),
100        };
101        if let Self::BTree(inner) = self
102            && inner.len() <= MAX_VEC_SIZE / 2
103        {
104            let btree = mem::take(inner);
105            let vec = Vec::from_iter(btree);
106            *self = Self::Vec(vec);
107        }
108        ret
109    }
110
111    pub fn len(&self) -> usize {
112        match self {
113            Self::BTree(inner) => inner.len(),
114            Self::Vec(inner) => inner.len(),
115        }
116    }
117
118    pub fn is_empty(&self) -> bool {
119        match self {
120            Self::BTree(inner) => inner.is_empty(),
121            Self::Vec(inner) => inner.is_empty(),
122        }
123    }
124
125    #[auto_enum(Iterator)]
126    pub fn values(&self) -> impl Iterator<Item = &V> {
127        match self {
128            Self::BTree(inner) => inner.values(),
129            Self::Vec(inner) => inner.iter().map(|(_, v)| v),
130        }
131    }
132
133    #[auto_enum(Iterator)]
134    pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> {
135        match self {
136            Self::BTree(inner) => inner.values_mut(),
137            Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v),
138        }
139    }
140
141    pub fn get(&self, key: &K) -> Option<&V> {
142        match self {
143            Self::BTree(inner) => inner.get(key),
144            Self::Vec(inner) => inner.iter().find(|(k, _)| k == key).map(|(_, v)| v),
145        }
146    }
147
148    /// Returns the key-value pair with smallest key in the map.
149    pub fn first_key_sorted(&self) -> Option<&K> {
150        match self {
151            Self::BTree(inner) => inner.first_key_value().map(|(k, _)| k),
152            Self::Vec(inner) => inner.iter().map(|(k, _)| k).min(),
153        }
154    }
155
156    /// Returns the smallest and second-smallest keys in the map.
157    pub fn first_two_key_sorted(&self) -> (Option<&K>, Option<&K>) {
158        match self {
159            Self::BTree(inner) => {
160                let mut iter = inner.keys();
161                (iter.next(), iter.next())
162            }
163            Self::Vec(inner) => {
164                let mut smallest = None;
165                let mut second = None;
166                for (k, _) in inner {
167                    if let Some(smallest_k) = smallest {
168                        if k < smallest_k {
169                            second = Some(smallest_k);
170                            smallest = Some(k);
171                        } else if second.is_none_or(|s| k < s) {
172                            second = Some(k);
173                        }
174                    } else {
175                        smallest = Some(k);
176                    }
177                }
178                (smallest, second)
179            }
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set by inserting `keys` in the given order, using `key * 10`
    /// as each value and asserting that every insertion succeeds.
    fn build_set(keys: &[i32]) -> JoinRowSet<i32, i32> {
        let mut set = JoinRowSet::default();
        for &key in keys {
            assert!(set.try_insert(key, key * 10).is_ok());
        }
        set
    }

    #[test]
    fn test_join_row_set_first_two_key_sorted() {
        {
            // Three unordered keys, still in the `Vec` representation.
            let join_row_set = build_set(&[3, 1, 2]);
            assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), Some(&2)));
        }
        {
            // Exactly two keys.
            let join_row_set = build_set(&[1, 2]);
            assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), Some(&2)));
        }
        {
            // A single key: there is no second-smallest key.
            let join_row_set = build_set(&[1]);
            assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), None));
        }
        {
            // An empty set yields no keys at all.
            let join_row_set = build_set(&[]);
            assert_eq!(join_row_set.first_two_key_sorted(), (None, None));
        }
    }

    #[test]
    fn test_join_row_set_first_two_key_sorted_btree() {
        // The fifth insertion finds the `Vec` at `MAX_VEC_SIZE` entries and
        // promotes it, so this exercises the `BTree` branch.
        let join_row_set = build_set(&[5, 3, 1, 4, 2]);
        assert!(matches!(join_row_set, JoinRowSet::BTree(_)));
        assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), Some(&2)));
    }

    #[test]
    fn test_join_row_set_first_two_key_sorted_after_shrink() {
        let mut join_row_set = build_set(&[5, 3, 1, 4, 2]);

        // Removing down to `MAX_VEC_SIZE / 2` entries demotes the set back to
        // the `Vec` representation.
        assert_eq!(join_row_set.remove(&1_i32), Some(10));
        assert_eq!(join_row_set.remove(&2_i32), Some(20));
        assert_eq!(join_row_set.remove(&3_i32), Some(30));

        assert!(matches!(join_row_set, JoinRowSet::Vec(_)));
        assert_eq!(join_row_set.first_two_key_sorted(), (Some(&4), Some(&5)));
    }
}
217}