risingwave_stream/executor/join/
join_row_set.rs1use std::borrow::Borrow;
16use std::collections::BTreeMap;
17use std::collections::btree_map::OccupiedError as BTreeMapOccupiedError;
18use std::fmt::Debug;
19use std::mem;
20
21use auto_enums::auto_enum;
22use enum_as_inner::EnumAsInner;
23
/// Largest number of entries a `JoinRowSet` keeps in its `Vec` representation.
///
/// `JoinRowSet::try_insert` converts a `Vec`-backed set into a `BTreeMap` once it already holds
/// this many entries, and `JoinRowSet::remove` converts a `BTreeMap`-backed set back into a
/// `Vec` once it is down to `MAX_VEC_SIZE / 2` entries or fewer.
const MAX_VEC_SIZE: usize = 4;
25
26#[derive(Debug, EnumAsInner)]
27pub enum JoinRowSet<K, V> {
28 BTree(BTreeMap<K, V>),
29 Vec(Vec<(K, V)>),
30}
31
32impl<K, V> Default for JoinRowSet<K, V> {
33 fn default() -> Self {
34 Self::Vec(Vec::new())
35 }
36}
37
/// Error payload of a rejected `JoinRowSet::try_insert` on the `Vec` representation.
///
/// This is the `Vec` counterpart of `std`'s `btree_map::OccupiedError`.
#[derive(Debug)]
#[expect(
    dead_code,
    reason = "fields are only surfaced through the derived `Debug` impl"
)]
pub struct VecOccupiedError<'a, K, V> {
    /// The key that is already stored in the set.
    key: &'a K,
    /// The value already stored for `key`; it is left untouched.
    old_value: &'a V,
    /// The value the caller tried to insert, handed back unchanged.
    new_value: V,
}
45
46#[derive(Debug)]
47pub enum JoinRowSetOccupiedError<'a, K: Ord, V> {
48 BTree(BTreeMapOccupiedError<'a, K, V>),
49 Vec(VecOccupiedError<'a, K, V>),
50}
51
52impl<K: Ord, V> JoinRowSet<K, V> {
53 pub fn try_insert(
54 &mut self,
55 key: K,
56 value: V,
57 ) -> Result<&'_ mut V, JoinRowSetOccupiedError<'_, K, V>> {
58 if let Self::Vec(inner) = self
59 && inner.len() >= MAX_VEC_SIZE
60 {
61 let btree = BTreeMap::from_iter(inner.drain(..));
62 *self = Self::BTree(btree);
63 }
64
65 match self {
66 Self::BTree(inner) => inner
67 .try_insert(key, value)
68 .map_err(JoinRowSetOccupiedError::BTree),
69 Self::Vec(inner) => {
70 if let Some(pos) = inner.iter().position(|elem| elem.0 == key) {
71 Err(JoinRowSetOccupiedError::Vec(VecOccupiedError {
72 key: &inner[pos].0,
73 old_value: &inner[pos].1,
74 new_value: value,
75 }))
76 } else {
77 if inner.capacity() == 0 {
78 inner.reserve_exact(1);
81 }
82 inner.push((key, value));
83 Ok(&mut inner.last_mut().unwrap().1)
84 }
85 }
86 }
87 }
88
89 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
90 where
91 K: Borrow<Q>,
92 Q: Ord + ?Sized,
93 {
94 let ret = match self {
95 Self::BTree(inner) => inner.remove(key),
96 Self::Vec(inner) => inner
97 .iter()
98 .position(|elem| elem.0.borrow() == key)
99 .map(|pos| inner.swap_remove(pos).1),
100 };
101 if let Self::BTree(inner) = self
102 && inner.len() <= MAX_VEC_SIZE / 2
103 {
104 let btree = mem::take(inner);
105 let vec = Vec::from_iter(btree);
106 *self = Self::Vec(vec);
107 }
108 ret
109 }
110
111 pub fn len(&self) -> usize {
112 match self {
113 Self::BTree(inner) => inner.len(),
114 Self::Vec(inner) => inner.len(),
115 }
116 }
117
118 pub fn is_empty(&self) -> bool {
119 match self {
120 Self::BTree(inner) => inner.is_empty(),
121 Self::Vec(inner) => inner.is_empty(),
122 }
123 }
124
125 #[auto_enum(Iterator)]
126 pub fn values(&self) -> impl Iterator<Item = &V> {
127 match self {
128 Self::BTree(inner) => inner.values(),
129 Self::Vec(inner) => inner.iter().map(|(_, v)| v),
130 }
131 }
132
133 #[auto_enum(Iterator)]
134 pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> {
135 match self {
136 Self::BTree(inner) => inner.values_mut(),
137 Self::Vec(inner) => inner.iter_mut().map(|(_, v)| v),
138 }
139 }
140
141 pub fn get(&self, key: &K) -> Option<&V> {
142 match self {
143 Self::BTree(inner) => inner.get(key),
144 Self::Vec(inner) => inner.iter().find(|(k, _)| k == key).map(|(_, v)| v),
145 }
146 }
147
148 pub fn first_key_sorted(&self) -> Option<&K> {
150 match self {
151 Self::BTree(inner) => inner.first_key_value().map(|(k, _)| k),
152 Self::Vec(inner) => inner.iter().map(|(k, _)| k).min(),
153 }
154 }
155
156 pub fn first_two_key_sorted(&self) -> (Option<&K>, Option<&K>) {
158 match self {
159 Self::BTree(inner) => {
160 let mut iter = inner.keys();
161 (iter.next(), iter.next())
162 }
163 Self::Vec(inner) => {
164 let mut smallest = None;
165 let mut second = None;
166 for (k, _) in inner {
167 if let Some(smallest_k) = smallest {
168 if k < smallest_k {
169 second = Some(smallest_k);
170 smallest = Some(k);
171 } else if second.is_none_or(|s| k < s) {
172 second = Some(k);
173 }
174 } else {
175 smallest = Some(k);
176 }
177 }
178 (smallest, second)
179 }
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_join_row_set_first_two_key_sorted() {
        {
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert!(join_row_set.try_insert(3, 30).is_ok());
            assert!(join_row_set.try_insert(1, 10).is_ok());
            assert!(join_row_set.try_insert(2, 20).is_ok());

            assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), Some(&2)));
        }
        {
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert!(join_row_set.try_insert(1, 10).is_ok());
            assert!(join_row_set.try_insert(2, 20).is_ok());

            assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), Some(&2)));
        }
        {
            let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert!(join_row_set.try_insert(1, 10).is_ok());

            assert_eq!(join_row_set.first_two_key_sorted(), (Some(&1), None));
        }
        {
            // An empty set has no keys at all.
            let join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();

            assert_eq!(join_row_set.first_two_key_sorted(), (None, None));
            assert_eq!(join_row_set.first_key_sorted(), None);
        }
    }

    /// Exercises the `Vec` representation: duplicate rejection, `swap_remove`-based removal
    /// (which reorders entries), value iteration and draining to empty.
    #[test]
    fn test_join_row_set_vec_operations() {
        let mut join_row_set: JoinRowSet<i32, i32> = JoinRowSet::default();
        assert!(join_row_set.is_empty());

        assert!(join_row_set.try_insert(5, 50).is_ok());
        assert!(join_row_set.try_insert(3, 30).is_ok());
        assert!(join_row_set.try_insert(8, 80).is_ok());
        assert_eq!(join_row_set.len(), 3);

        // A duplicate key is rejected and the stored value is kept.
        assert!(join_row_set.try_insert(3, 0).is_err());
        assert_eq!(join_row_set.get(&3), Some(&30));
        assert_eq!(join_row_set.len(), 3);

        // Removal may reorder the remaining entries; sorted queries must not depend on order.
        assert_eq!(join_row_set.remove(&5), Some(50));
        assert_eq!(join_row_set.remove(&5), None);
        assert_eq!(join_row_set.get(&5), None);
        assert_eq!(join_row_set.first_key_sorted(), Some(&3));
        assert_eq!(join_row_set.first_two_key_sorted(), (Some(&3), Some(&8)));

        for value in join_row_set.values_mut() {
            *value += 1;
        }
        let mut values: Vec<i32> = join_row_set.values().copied().collect();
        values.sort_unstable();
        assert_eq!(values, vec![31, 81]);

        assert_eq!(join_row_set.remove(&3), Some(31));
        assert_eq!(join_row_set.remove(&8), Some(81));
        assert!(join_row_set.is_empty());
        assert_eq!(join_row_set.first_key_sorted(), None);
    }

    /// Checks the `Vec` -> `BTreeMap` promotion and the `BTreeMap` -> `Vec` demotion, plus
    /// queries against the `BTreeMap` representation. Written in terms of `MAX_VEC_SIZE` so it
    /// tracks the threshold if it changes.
    #[test]
    fn test_join_row_set_switch_representation() {
        let mut join_row_set: JoinRowSet<usize, usize> = JoinRowSet::default();

        // Up to `MAX_VEC_SIZE` entries stay in the `Vec` representation.
        for key in 0..MAX_VEC_SIZE {
            assert!(join_row_set.try_insert(key, key * 10).is_ok());
        }
        assert!(matches!(join_row_set, JoinRowSet::Vec(_)));

        // The next insertion promotes the set to a `BTreeMap`.
        assert!(join_row_set
            .try_insert(MAX_VEC_SIZE, MAX_VEC_SIZE * 10)
            .is_ok());
        assert!(matches!(join_row_set, JoinRowSet::BTree(_)));
        assert_eq!(join_row_set.len(), MAX_VEC_SIZE + 1);
        assert_eq!(join_row_set.first_key_sorted(), Some(&0));
        assert_eq!(join_row_set.first_two_key_sorted(), (Some(&0), Some(&1)));

        // Duplicates are rejected in the `BTreeMap` representation as well.
        assert!(join_row_set.try_insert(1, 999).is_err());
        assert_eq!(join_row_set.get(&1), Some(&10));

        // Removals keep the `BTreeMap` while more than `MAX_VEC_SIZE / 2` entries remain.
        let mut next_key = 0;
        while join_row_set.len() > MAX_VEC_SIZE / 2 + 1 {
            assert_eq!(join_row_set.remove(&next_key), Some(next_key * 10));
            next_key += 1;
            assert!(matches!(join_row_set, JoinRowSet::BTree(_)));
        }

        // Reaching `MAX_VEC_SIZE / 2` entries demotes the set back to a `Vec`.
        assert_eq!(join_row_set.remove(&next_key), Some(next_key * 10));
        next_key += 1;
        assert!(matches!(join_row_set, JoinRowSet::Vec(_)));
        assert_eq!(join_row_set.len(), MAX_VEC_SIZE / 2);
        assert_eq!(join_row_set.first_key_sorted(), Some(&next_key));
        assert_eq!(
            join_row_set.first_two_key_sorted(),
            (Some(&next_key), Some(&(next_key + 1)))
        );
    }
}