risingwave_storage/vector/
utils.rs1use std::cmp::{Ordering, min};
16use std::collections::BinaryHeap;
17use std::mem::replace;
18
19use crate::vector::VectorDistance;
20
21pub(super) fn compare_distance(first: VectorDistance, second: VectorDistance) -> Ordering {
22 first
23 .partial_cmp(&second)
24 .unwrap_or_else(|| panic!("failed to compare distance {} and {}.", first, second))
25}
26
27fn compare_distance_on_heap<const MAX_HEAP: bool>(
28 first: VectorDistance,
29 second: VectorDistance,
30) -> Ordering {
31 let (first, second) = if MAX_HEAP {
32 (first, second)
33 } else {
34 (second, first)
35 };
36 compare_distance(first, second)
37}
38
39pub(super) struct HeapNode<I, const MAX_HEAP: bool> {
40 distance: VectorDistance,
41 item: I,
42}
43
44impl<I, const MAX_HEAP: bool> PartialEq for HeapNode<I, MAX_HEAP> {
45 fn eq(&self, other: &Self) -> bool {
46 self.distance.eq(&other.distance)
47 }
48}
49
50impl<I, const MAX_HEAP: bool> Eq for HeapNode<I, MAX_HEAP> {}
51
52impl<I, const MAX_HEAP: bool> PartialOrd for HeapNode<I, MAX_HEAP> {
53 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54 Some(self.cmp(other))
55 }
56}
57
58impl<I, const MAX_HEAP: bool> Ord for HeapNode<I, MAX_HEAP> {
59 fn cmp(&self, other: &Self) -> Ordering {
60 compare_distance_on_heap::<MAX_HEAP>(self.distance, other.distance)
61 }
62}
63
64pub struct DistanceHeap<I, const MAX_HEAP: bool>(BinaryHeap<HeapNode<I, MAX_HEAP>>);
65
66pub type MaxDistanceHeap<I> = DistanceHeap<I, true>;
67pub type MinDistanceHeap<I> = DistanceHeap<I, false>;
68
/// Normalizes a requested capacity so that it is never zero.
///
/// A zero capacity is treated as a caller bug. Debug builds panic with
/// `"unexpected 0 capacity"`; release builds fall back to `1`.
fn non_zero_capacity(capacity: usize) -> usize {
    match capacity {
        0 if cfg!(debug_assertions) => panic!("unexpected 0 capacity"),
        0 => 1,
        nonzero => nonzero,
    }
}
80
81impl<I, const MAX_HEAP: bool> DistanceHeap<I, MAX_HEAP> {
82 pub fn with_capacity(capacity: usize) -> Self {
83 let capacity = non_zero_capacity(capacity);
84 Self(BinaryHeap::with_capacity(capacity))
85 }
86
87 pub fn push(&mut self, distance: VectorDistance, item: I) {
88 self.0.push(HeapNode { distance, item });
89 }
90
91 pub fn top(&self) -> Option<(VectorDistance, &I)> {
92 self.0.peek().map(|node| (node.distance, &node.item))
93 }
94
95 pub fn pop(&mut self) -> Option<(VectorDistance, I)> {
96 self.0.pop().map(|node| (node.distance, node.item))
97 }
98}
99
100pub struct BoundedNearest<I> {
101 heap: MaxDistanceHeap<I>,
102 capacity: usize,
103}
104
105impl<I> BoundedNearest<I> {
106 pub fn new(capacity: usize) -> Self {
107 Self {
108 heap: DistanceHeap(BinaryHeap::with_capacity(capacity)),
109 capacity,
110 }
111 }
112
113 pub fn furthest(&self) -> Option<(VectorDistance, &I)> {
114 self.heap.top()
115 }
116
117 pub fn insert(
118 &mut self,
119 distance: VectorDistance,
120 get_item: impl FnOnce() -> I,
121 ) -> Option<(VectorDistance, I)> {
122 if self.heap.0.len() >= self.capacity {
123 let mut top = self.heap.0.peek_mut().expect("non-empty");
126 if top.distance > distance {
127 let prev_node = replace(
128 &mut *top,
129 HeapNode {
130 distance,
131 item: get_item(),
132 },
133 );
134 Some((prev_node.distance, prev_node.item))
135 } else {
136 None
137 }
138 } else {
139 self.heap.0.push(HeapNode {
140 distance,
141 item: get_item(),
142 });
143 None
144 }
145 }
146
147 pub fn collect(self) -> Vec<I> {
148 self.collect_with(|_, item| item, None)
149 }
150
151 pub fn collect_with<O>(
152 mut self,
153 mut f: impl FnMut(VectorDistance, I) -> O,
154 limit: Option<usize>,
155 ) -> Vec<O> {
156 let size = self.heap.0.len();
157 let size = if let Some(limit) = limit {
158 min(size, limit)
159 } else {
160 size
161 };
162 let mut vec = Vec::with_capacity(size);
163 let uninit_slice = vec.spare_capacity_mut();
164 while self.heap.0.len() > size {
165 self.heap.pop();
166 }
167 assert_eq!(size, self.heap.0.len());
168 let mut i = size;
169 while let Some(node) = self.heap.0.pop() {
171 i -= 1;
172 unsafe {
175 uninit_slice
176 .get_unchecked_mut(i)
177 .write(f(node.distance, node.item));
178 }
179 }
180 assert_eq!(i, 0);
181 unsafe { vec.set_len(size) }
183 vec
184 }
185
186 pub fn resize(&mut self, new_capacity: usize) {
187 let new_capacity = non_zero_capacity(new_capacity);
188 self.capacity = new_capacity;
189 while self.heap.0.len() > new_capacity {
190 self.heap.pop();
191 }
192 }
193
194 #[expect(clippy::len_without_is_empty)]
195 pub fn len(&self) -> usize {
196 self.heap.0.len()
197 }
198}
199
200impl<'a, I> IntoIterator for &'a BoundedNearest<I> {
201 type Item = (VectorDistance, &'a I);
202
203 type IntoIter = impl Iterator<Item = Self::Item>;
204
205 fn into_iter(self) -> Self::IntoIter {
206 self.heap.0.iter().map(|node| (node.distance, &node.item))
207 }
208}
209
210impl<I> IntoIterator for BoundedNearest<I> {
211 type Item = (VectorDistance, I);
212
213 type IntoIter = impl Iterator<Item = Self::Item>;
214
215 fn into_iter(self) -> Self::IntoIter {
216 self.heap
217 .0
218 .into_iter()
219 .map(|node| (node.distance, node.item))
220 }
221}
222
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use itertools::Itertools;
    use rand::{Rng, rng};
    use risingwave_common::array::VectorDistanceType;

    use crate::vector::test_utils::{gen_info, top_n};
    use crate::vector::utils::BoundedNearest;

    /// Feeds `count` random distances into a `BoundedNearest` of capacity `n`.
    /// It then collects them with `limit` and checks the result against the
    /// reference selection that `top_n` produces over the full input.
    fn test_inner(count: usize, n: usize, limit: Option<usize>) {
        let input = (0..count)
            .map(|i| (rng().random::<VectorDistanceType>(), i))
            .collect_vec();

        let mut nearest = BoundedNearest::new(n);
        input.iter().for_each(|&(distance, item)| {
            nearest.insert(distance, || item);
        });
        let actual = nearest.collect_with(|_, i| gen_info(i), limit);

        // The collection never yields more than its capacity, nor more than
        // the requested limit.
        let expected_len = limit.map_or(n, |limit| min(n, limit));
        let mut reference = input
            .iter()
            .map(|&(distance, i)| (distance, gen_info(i)))
            .collect_vec();
        top_n(&mut reference, expected_len);
        let expected = reference.into_iter().map(|(_, info)| info).collect_vec();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_not_full_top_n() {
        for limit in [None, Some(3), Some(5), Some(7)] {
            test_inner(5, 10, limit);
        }
    }

    #[test]
    fn test_exact_size_top_n() {
        for limit in [None, Some(8), Some(10), Some(12)] {
            test_inner(10, 10, limit);
        }
    }

    #[test]
    fn test_oversize_top_n() {
        for limit in [None, Some(8), Some(10), Some(15), Some(20), Some(25)] {
            test_inner(20, 10, limit);
        }
    }
}