risingwave_stream/executor/join/
join_row_set.rs1use std::borrow::Borrow;
16use std::collections::BTreeMap;
17use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError;
18use std::fmt::Debug;
19use std::mem;
20use std::ops::{Bound, RangeBounds};
21
22use auto_enums::auto_enum;
23use enum_as_inner::EnumAsInner;
24
/// Largest number of entries kept in the `Vec` representation of [`JoinRowSet`].
///
/// `JoinRowSet::try_insert` promotes a `Vec` that already holds this many entries (or more) to
/// a `BTreeMap` before inserting. `JoinRowSet::remove` demotes a `BTreeMap` back to a `Vec`
/// once it holds `MAX_VEC_SIZE / 2` entries or fewer.
const MAX_VEC_SIZE: usize = 4;
26
/// A map from `K` to `V` that switches its backing storage by size.
///
/// Small sets are stored as an unsorted `Vec<(K, V)>` and searched linearly. An insert that
/// finds `MAX_VEC_SIZE` or more entries in the `Vec` first converts it into a `BTreeMap`. A
/// removal that leaves a `BTreeMap` with `MAX_VEC_SIZE / 2` entries or fewer converts it back.
/// A new set (see the `Default` impl) starts as an empty `Vec`.
#[derive(Debug, EnumAsInner)]
pub enum JoinRowSet<K, V> {
    /// Ordered storage used once the set has grown past the `Vec` limit.
    BTree(BTreeMap<K, V>),
    /// Storage for small sets. Keys are unique but kept in no particular order.
    Vec(Vec<(K, V)>),
}
32
33impl<K, V> Default for JoinRowSet<K, V> {
34 fn default() -> Self {
35 Self::Vec(Vec::new())
36 }
37}
38
/// Error payload produced by `JoinRowSet::try_insert` when the key is already present in the
/// `Vec` representation.
///
/// It borrows the existing entry and hands back the value that was not inserted.
#[derive(Debug)]
// The fields are private and only consumed by the derived `Debug` impl, which the
// dead-code lint does not count as a read; hence the allow.
#[allow(dead_code)]
pub struct VecOccupiedError<'a, K, V> {
    /// The key that was already present.
    key: &'a K,
    /// The value currently stored under `key`.
    old_value: &'a V,
    /// The value the caller tried to insert; it was not stored.
    new_value: V,
}
46
/// Error returned by `JoinRowSet::try_insert` when the key is already present.
///
/// It wraps the occupied error of whichever representation backed the set at insert time.
#[derive(Debug)]
pub enum JoinRowSetOccupiedError<'a, K: Ord, V> {
    /// The set was backed by a `BTreeMap`.
    BTree(BTreeMapOccupiedError<'a, K, V>),
    /// The set was backed by a `Vec`.
    Vec(VecOccupiedError<'a, K, V>),
}
52
53impl<K: Ord, V> JoinRowSet<K, V> {
54 pub fn try_insert(
55 &mut self,
56 key: K,
57 value: V,
58 ) -> Result<&'_ mut V, JoinRowSetOccupiedError<'_, K, V>> {
59 if let Self::Vec(inner) = self
60 && inner.len() >= MAX_VEC_SIZE
61 {
62 let btree = BTreeMap::from_iter(inner.drain(..));
63 mem::swap(self, &mut Self::BTree(btree));
64 }
65
66 match self {
67 Self::BTree(inner) => inner
68 .try_insert(key, value)
69 .map_err(JoinRowSetOccupiedError::BTree),
70 Self::Vec(inner) => {
71 if let Some(pos) = inner.iter().position(|elem| elem.0 == key) {
72 Err(JoinRowSetOccupiedError::Vec(VecOccupiedError {
73 key: &inner[pos].0,
74 old_value: &inner[pos].1,
75 new_value: value,
76 }))
77 } else {
78 if inner.capacity() == 0 {
79 inner.reserve_exact(1);
82 }
83 inner.push((key, value));
84 Ok(&mut inner.last_mut().unwrap().1)
85 }
86 }
87 }
88 }
89
90 pub fn remove(&mut self, key: &K) -> Option<V> {
91 let ret = match self {
92 Self::BTree(inner) => inner.remove(key),
93 Self::Vec(inner) => inner
94 .iter()
95 .position(|elem| &elem.0 == key)
96 .map(|pos| inner.swap_remove(pos).1),
97 };
98 if let Self::BTree(inner) = self
99 && inner.len() <= MAX_VEC_SIZE / 2
100 {
101 let btree = mem::take(inner);
102 let vec = Vec::from_iter(btree);
103 mem::swap(self, &mut Self::Vec(vec));
104 }
105 ret
106 }
107
108 pub fn len(&self) -> usize {
109 match self {
110 Self::BTree(inner) => inner.len(),
111 Self::Vec(inner) => inner.len(),
112 }
113 }
114
115 pub fn is_empty(&self) -> bool {
116 match self {
117 Self::BTree(inner) => inner.is_empty(),
118 Self::Vec(inner) => inner.is_empty(),
119 }
120 }
121
122 #[auto_enum(Iterator)]
123 pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> {
124 match self {
125 Self::BTree(inner) => inner.values_mut(),
126 Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v),
127 }
128 }
129
130 #[auto_enum(Iterator)]
131 pub fn keys(&self) -> impl Iterator<Item = &K> {
132 match self {
133 Self::BTree(inner) => inner.keys(),
134 Self::Vec(inner) => inner.iter().map(|(k, _v)| k),
135 }
136 }
137
138 #[auto_enum(Iterator)]
139 pub fn range<T, R>(&self, range: R) -> impl Iterator<Item = (&K, &V)>
140 where
141 T: Ord + ?Sized,
142 K: Borrow<T> + Ord,
143 R: RangeBounds<T>,
144 {
145 match self {
146 Self::BTree(inner) => inner.range(range),
147 Self::Vec(inner) => inner
148 .iter()
149 .filter(move |(k, _)| range.contains(k.borrow()))
150 .map(|(k, v)| (k, v)),
151 }
152 }
153
154 pub fn lower_bound_key(&self, bound: Bound<&K>) -> Option<&K> {
155 self.lower_bound(bound).map(|(k, _v)| k)
156 }
157
158 pub fn upper_bound_key(&self, bound: Bound<&K>) -> Option<&K> {
159 self.upper_bound(bound).map(|(k, _v)| k)
160 }
161
162 pub fn lower_bound(&self, bound: Bound<&K>) -> Option<(&K, &V)> {
163 match self {
164 Self::BTree(inner) => inner.lower_bound(bound).next(),
165 Self::Vec(inner) => inner
166 .iter()
167 .filter(|(k, _)| (bound, Bound::Unbounded).contains(k))
168 .min_by_key(|(k, _)| k)
169 .map(|(k, v)| (k, v)),
170 }
171 }
172
173 pub fn upper_bound(&self, bound: Bound<&K>) -> Option<(&K, &V)> {
174 match self {
175 Self::BTree(inner) => inner.upper_bound(bound).prev(),
176 Self::Vec(inner) => inner
177 .iter()
178 .filter(|(k, _)| (Bound::Unbounded, bound).contains(k))
179 .max_by_key(|(k, _)| k)
180 .map(|(k, v)| (k, v)),
181 }
182 }
183
184 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
185 match self {
186 Self::BTree(inner) => inner.get_mut(key),
187 Self::Vec(inner) => inner.iter_mut().find(|(k, _)| k == key).map(|(_, v)| v),
188 }
189 }
190
191 pub fn get(&self, key: &K) -> Option<&V> {
192 match self {
193 Self::BTree(inner) => inner.get(key),
194 Self::Vec(inner) => inner.iter().find(|(k, _)| k == key).map(|(_, v)| v),
195 }
196 }
197
198 pub fn first_key_sorted(&self) -> Option<&K> {
200 match self {
201 Self::BTree(inner) => inner.first_key_value().map(|(k, _)| k),
202 Self::Vec(inner) => inner.iter().map(|(k, _)| k).min(),
203 }
204 }
205
206 pub fn second_key_sorted(&self) -> Option<&K> {
208 match self {
209 Self::BTree(inner) => inner.iter().nth(1).map(|(k, _)| k),
210 Self::Vec(inner) => {
211 let mut res = None;
212 let mut smallest = None;
213 for (k, _) in inner {
214 if let Some(smallest_k) = smallest {
215 if k < smallest_k {
216 res = Some(smallest_k);
217 smallest = Some(k);
218 } else if let Some(res_k) = res {
219 if k < res_k {
220 res = Some(k);
221 }
222 } else {
223 res = Some(k);
224 }
225 } else {
226 smallest = Some(k);
227 }
228 }
229 res
230 }
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects keys into a sorted `Vec`, so assertions do not depend on the unsorted
    /// iteration order of the `Vec` representation.
    fn sorted_keys<'a>(keys: impl Iterator<Item = &'a i32>) -> Vec<i32> {
        let mut keys: Vec<i32> = keys.copied().collect();
        keys.sort_unstable();
        keys
    }

    /// Number of entries that forces the set into the `BTree` representation.
    fn btree_len() -> i32 {
        i32::try_from(MAX_VEC_SIZE + 1).expect("MAX_VEC_SIZE fits in i32")
    }

    #[test]
    fn test_join_row_set_bounds() {
        let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

        assert!(join_row_set.try_insert(1, 10).is_ok());
        assert!(join_row_set.try_insert(2, 20).is_ok());
        assert!(join_row_set.try_insert(3, 30).is_ok());

        assert_eq!(join_row_set.lower_bound_key(Bound::Included(&2)), Some(&2));
        assert_eq!(join_row_set.lower_bound_key(Bound::Excluded(&2)), Some(&3));
        assert_eq!(join_row_set.lower_bound_key(Bound::Excluded(&3)), None);
        assert_eq!(join_row_set.lower_bound_key(Bound::Unbounded), Some(&1));

        assert_eq!(join_row_set.upper_bound_key(Bound::Included(&2)), Some(&2));
        assert_eq!(join_row_set.upper_bound_key(Bound::Excluded(&2)), Some(&1));
        assert_eq!(join_row_set.upper_bound_key(Bound::Excluded(&1)), None);
        assert_eq!(join_row_set.upper_bound_key(Bound::Unbounded), Some(&3));
    }

    #[test]
    fn test_join_row_set_first_and_second_key_sorted() {
        {
            let join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert_eq!(join_row_set.first_key_sorted(), None);
            assert_eq!(join_row_set.second_key_sorted(), None);
        }
        {
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert!(join_row_set.try_insert(7, 70).is_ok());

            assert_eq!(join_row_set.first_key_sorted(), Some(&7));
            assert_eq!(join_row_set.second_key_sorted(), None);
        }
        {
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert!(join_row_set.try_insert(3, 30).is_ok());
            assert!(join_row_set.try_insert(1, 10).is_ok());
            assert!(join_row_set.try_insert(2, 20).is_ok());

            assert_eq!(join_row_set.first_key_sorted(), Some(&1));
            assert_eq!(join_row_set.second_key_sorted(), Some(&2));
        }
        {
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert!(join_row_set.try_insert(1, 10).is_ok());
            assert!(join_row_set.try_insert(2, 20).is_ok());

            assert_eq!(join_row_set.first_key_sorted(), Some(&1));
            assert_eq!(join_row_set.second_key_sorted(), Some(&2));
        }
        {
            // The runner-up is replaced several times while scanning the unsorted `Vec`.
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            for key in [5, 1, 4, 2] {
                assert!(join_row_set.try_insert(key, key * 10).is_ok());
            }

            assert_eq!(join_row_set.first_key_sorted(), Some(&1));
            assert_eq!(join_row_set.second_key_sorted(), Some(&2));
        }
        {
            // Large enough to be stored as a `BTreeMap`; inserted in descending order.
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            for key in (1..=btree_len()).rev() {
                assert!(join_row_set.try_insert(key, key * 10).is_ok());
            }
            assert!(matches!(join_row_set, JoinRowSet::BTree(_)));

            assert_eq!(join_row_set.first_key_sorted(), Some(&1));
            assert_eq!(join_row_set.second_key_sorted(), Some(&2));
        }
    }

    #[test]
    fn test_join_row_set_vec_get_and_mutate() {
        let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

        assert!(join_row_set.is_empty());
        assert_eq!(join_row_set.get(&1), None);
        assert_eq!(join_row_set.remove(&1), None);
        assert_eq!(join_row_set.lower_bound_key(Bound::Unbounded), None);

        assert!(join_row_set.try_insert(1, 10).is_ok());
        assert!(join_row_set.try_insert(2, 20).is_ok());
        match join_row_set.try_insert(1, 100) {
            Err(JoinRowSetOccupiedError::Vec(err)) => {
                assert_eq!(*err.key, 1);
                assert_eq!(*err.old_value, 10);
                assert_eq!(err.new_value, 100);
            }
            Ok(_) | Err(JoinRowSetOccupiedError::BTree(_)) => {
                panic!("expected a Vec occupied error")
            }
        }
        assert_eq!(join_row_set.len(), 2);
        assert_eq!(join_row_set.get(&1), Some(&10));

        *join_row_set.get_mut(&2).expect("key 2 was inserted") = 21;
        for value in join_row_set.values_mut() {
            *value += 1;
        }
        assert_eq!(join_row_set.get(&1), Some(&11));
        assert_eq!(join_row_set.get(&2), Some(&22));

        assert_eq!(join_row_set.remove(&1), Some(11));
        assert_eq!(join_row_set.remove(&1), None);
        assert_eq!(sorted_keys(join_row_set.keys()), vec![2]);
    }

    #[test]
    fn test_join_row_set_promote_and_demote() {
        let n = btree_len();
        let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

        // Up to `MAX_VEC_SIZE` entries stay in the `Vec` representation.
        for key in 1..n {
            assert!(join_row_set.try_insert(key, key * 10).is_ok());
            assert!(matches!(join_row_set, JoinRowSet::Vec(_)));
        }

        // The next insert finds a full `Vec` and converts the set into a `BTreeMap`.
        assert!(join_row_set.try_insert(n, n * 10).is_ok());
        assert!(matches!(join_row_set, JoinRowSet::BTree(_)));
        assert_eq!(join_row_set.len(), MAX_VEC_SIZE + 1);

        // Duplicate keys are rejected and leave the stored value untouched.
        assert!(join_row_set.try_insert(2, 0).is_err());
        assert_eq!(join_row_set.get(&2), Some(&20));

        // Lookups agree with the `Vec` representation.
        assert_eq!(join_row_set.lower_bound_key(Bound::Included(&2)), Some(&2));
        assert_eq!(join_row_set.lower_bound_key(Bound::Excluded(&2)), Some(&3));
        assert_eq!(join_row_set.lower_bound_key(Bound::Excluded(&n)), None);
        assert_eq!(join_row_set.upper_bound_key(Bound::Included(&2)), Some(&2));
        assert_eq!(join_row_set.upper_bound_key(Bound::Excluded(&2)), Some(&1));
        assert_eq!(join_row_set.upper_bound_key(Bound::Excluded(&1)), None);
        assert_eq!(
            sorted_keys(join_row_set.range(2..4).map(|(k, _)| k)),
            vec![2, 3]
        );
        assert_eq!(
            sorted_keys(join_row_set.keys()),
            (1..=n).collect::<Vec<_>>()
        );

        // Removing entries converts the set back into a `Vec` once it holds
        // `MAX_VEC_SIZE / 2` entries or fewer.
        for key in (1..=n).rev() {
            assert_eq!(join_row_set.remove(&key), Some(key * 10));
            if join_row_set.len() > MAX_VEC_SIZE / 2 {
                assert!(matches!(join_row_set, JoinRowSet::BTree(_)));
            } else {
                assert!(matches!(join_row_set, JoinRowSet::Vec(_)));
            }
        }
        assert!(join_row_set.is_empty());
        assert_eq!(join_row_set.remove(&1), None);
    }
}