risingwave_storage/vector/
utils.rs1use std::cmp::{Ordering, min};
16use std::collections::BinaryHeap;
17use std::mem::replace;
18
19use crate::vector::VectorDistance;
20
21pub(super) fn compare_distance(first: VectorDistance, second: VectorDistance) -> Ordering {
22 first
23 .partial_cmp(&second)
24 .unwrap_or_else(|| panic!("failed to compare distance {} and {}.", first, second))
25}
26
27fn compare_distance_on_heap<const MAX_HEAP: bool>(
28 first: VectorDistance,
29 second: VectorDistance,
30) -> Ordering {
31 let (first, second) = if MAX_HEAP {
32 (first, second)
33 } else {
34 (second, first)
35 };
36 compare_distance(first, second)
37}
38
39pub(super) struct HeapNode<I, const MAX_HEAP: bool> {
40 distance: VectorDistance,
41 item: I,
42}
43
44impl<I, const MAX_HEAP: bool> PartialEq for HeapNode<I, MAX_HEAP> {
45 fn eq(&self, other: &Self) -> bool {
46 self.distance.eq(&other.distance)
47 }
48}
49
50impl<I, const MAX_HEAP: bool> Eq for HeapNode<I, MAX_HEAP> {}
51
52impl<I, const MAX_HEAP: bool> PartialOrd for HeapNode<I, MAX_HEAP> {
53 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54 Some(self.cmp(other))
55 }
56}
57
58impl<I, const MAX_HEAP: bool> Ord for HeapNode<I, MAX_HEAP> {
59 fn cmp(&self, other: &Self) -> Ordering {
60 compare_distance_on_heap::<MAX_HEAP>(self.distance, other.distance)
61 }
62}
63
64pub struct DistanceHeap<I, const MAX_HEAP: bool>(BinaryHeap<HeapNode<I, MAX_HEAP>>);
65
66pub type MaxDistanceHeap<I> = DistanceHeap<I, true>;
67pub type MinDistanceHeap<I> = DistanceHeap<I, false>;
68
69impl<I, const MAX_HEAP: bool> DistanceHeap<I, MAX_HEAP> {
70 pub fn with_capacity(capacity: usize) -> Self {
71 Self(BinaryHeap::with_capacity(capacity))
72 }
73
74 pub fn push(&mut self, distance: VectorDistance, item: I) {
75 self.0.push(HeapNode { distance, item });
76 }
77
78 pub fn top(&self) -> Option<(VectorDistance, &I)> {
79 self.0.peek().map(|node| (node.distance, &node.item))
80 }
81
82 pub fn pop(&mut self) -> Option<(VectorDistance, I)> {
83 self.0.pop().map(|node| (node.distance, node.item))
84 }
85}
86
87pub struct BoundedNearest<I> {
88 heap: MaxDistanceHeap<I>,
89 capacity: usize,
90}
91
92impl<I> BoundedNearest<I> {
93 pub fn new(capacity: usize) -> Self {
94 Self {
95 heap: DistanceHeap(BinaryHeap::with_capacity(capacity)),
96 capacity,
97 }
98 }
99
100 pub fn furthest(&self) -> Option<(VectorDistance, &I)> {
101 self.heap.top()
102 }
103
104 pub fn insert(
105 &mut self,
106 distance: VectorDistance,
107 get_item: impl FnOnce() -> I,
108 ) -> Option<(VectorDistance, I)> {
109 if self.heap.0.len() >= self.capacity {
110 let mut top = self.heap.0.peek_mut().expect("non-empty");
111 if top.distance > distance {
112 let prev_node = replace(
113 &mut *top,
114 HeapNode {
115 distance,
116 item: get_item(),
117 },
118 );
119 Some((prev_node.distance, prev_node.item))
120 } else {
121 None
122 }
123 } else {
124 self.heap.0.push(HeapNode {
125 distance,
126 item: get_item(),
127 });
128 None
129 }
130 }
131
132 pub fn collect(self) -> Vec<I> {
133 self.collect_with(|item| item, None)
134 }
135
136 pub fn collect_with<O>(mut self, mut f: impl FnMut(I) -> O, limit: Option<usize>) -> Vec<O> {
137 let size = self.heap.0.len();
138 let size = if let Some(limit) = limit {
139 min(size, limit)
140 } else {
141 size
142 };
143 let mut vec = Vec::with_capacity(size);
144 let uninit_slice = vec.spare_capacity_mut();
145 while self.heap.0.len() > size {
146 self.heap.pop();
147 }
148 assert_eq!(size, self.heap.0.len());
149 let mut i = size;
150 while let Some(node) = self.heap.0.pop() {
152 i -= 1;
153 unsafe {
156 uninit_slice.get_unchecked_mut(i).write(f(node.item));
157 }
158 }
159 assert_eq!(i, 0);
160 unsafe { vec.set_len(size) }
162 vec
163 }
164
165 pub fn resize(&mut self, new_capacity: usize) {
166 self.capacity = new_capacity;
167 while self.heap.0.len() > new_capacity {
168 self.heap.pop();
169 }
170 }
171}
172
173impl<'a, I> IntoIterator for &'a BoundedNearest<I> {
174 type Item = (VectorDistance, &'a I);
175
176 type IntoIter = impl Iterator<Item = Self::Item>;
177
178 fn into_iter(self) -> Self::IntoIter {
179 self.heap.0.iter().map(|node| (node.distance, &node.item))
180 }
181}
182
183impl<I> IntoIterator for BoundedNearest<I> {
184 type Item = (VectorDistance, I);
185
186 type IntoIter = impl Iterator<Item = Self::Item>;
187
188 fn into_iter(self) -> Self::IntoIter {
189 self.heap
190 .0
191 .into_iter()
192 .map(|node| (node.distance, node.item))
193 }
194}
195
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use itertools::Itertools;
    use rand::{Rng, rng};

    use crate::vector::test_utils::{gen_info, top_n};
    use crate::vector::utils::BoundedNearest;

    /// Feeds `count` items with random distances into a `BoundedNearest` of capacity
    /// `n`, then checks `collect_with(gen_info, limit)` against the output of the
    /// `top_n` test helper applied to the same input.
    fn test_inner(count: usize, n: usize, limit: Option<usize>) {
        let input = (0..count).map(|i| (rng().random::<f32>(), i)).collect_vec();

        let mut nearest = BoundedNearest::new(n);
        for (distance, item) in input.iter().copied() {
            nearest.insert(distance, || item);
        }
        let output = nearest.collect_with(gen_info, limit);

        // The expected length is capped by both the capacity and the optional limit.
        let expected_len = limit.map_or(n, |limit| min(n, limit));
        let mut expected_output = input
            .iter()
            .map(|&(distance, i)| (distance, gen_info(i)))
            .collect_vec();
        top_n(&mut expected_output, expected_len);
        let expected_output = expected_output
            .into_iter()
            .map(|(_, info)| info)
            .collect_vec();

        assert_eq!(output, expected_output);
    }

    /// Fewer inputs than the capacity.
    #[test]
    fn test_not_full_top_n() {
        for limit in [None, Some(3), Some(5), Some(7)] {
            test_inner(5, 10, limit);
        }
    }

    /// Exactly as many inputs as the capacity.
    #[test]
    fn test_exact_size_top_n() {
        for limit in [None, Some(8), Some(10), Some(12)] {
            test_inner(10, 10, limit);
        }
    }

    /// More inputs than the capacity, so some must be evicted.
    #[test]
    fn test_oversize_top_n() {
        for limit in [None, Some(8), Some(10), Some(15), Some(20), Some(25)] {
            test_inner(20, 10, limit);
        }
    }
}