risingwave_storage/vector/
utils.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::cmp::{Ordering, min};
16use std::collections::BinaryHeap;
17use std::mem::replace;
18
19use crate::vector::VectorDistance;
20
21pub(super) fn compare_distance(first: VectorDistance, second: VectorDistance) -> Ordering {
22    first
23        .partial_cmp(&second)
24        .unwrap_or_else(|| panic!("failed to compare distance {} and {}.", first, second))
25}
26
27fn compare_distance_on_heap<const MAX_HEAP: bool>(
28    first: VectorDistance,
29    second: VectorDistance,
30) -> Ordering {
31    let (first, second) = if MAX_HEAP {
32        (first, second)
33    } else {
34        (second, first)
35    };
36    compare_distance(first, second)
37}
38
39pub(super) struct HeapNode<I, const MAX_HEAP: bool> {
40    distance: VectorDistance,
41    item: I,
42}
43
44impl<I, const MAX_HEAP: bool> PartialEq for HeapNode<I, MAX_HEAP> {
45    fn eq(&self, other: &Self) -> bool {
46        self.distance.eq(&other.distance)
47    }
48}
49
50impl<I, const MAX_HEAP: bool> Eq for HeapNode<I, MAX_HEAP> {}
51
52impl<I, const MAX_HEAP: bool> PartialOrd for HeapNode<I, MAX_HEAP> {
53    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54        Some(self.cmp(other))
55    }
56}
57
58impl<I, const MAX_HEAP: bool> Ord for HeapNode<I, MAX_HEAP> {
59    fn cmp(&self, other: &Self) -> Ordering {
60        compare_distance_on_heap::<MAX_HEAP>(self.distance, other.distance)
61    }
62}
63
64pub struct DistanceHeap<I, const MAX_HEAP: bool>(BinaryHeap<HeapNode<I, MAX_HEAP>>);
65
66pub type MaxDistanceHeap<I> = DistanceHeap<I, true>;
67pub type MinDistanceHeap<I> = DistanceHeap<I, false>;
68
/// Returns `capacity` unchanged when it is non-zero.
///
/// A zero capacity is treated as a caller bug: debug builds panic with
/// "unexpected 0 capacity", release builds fall back to a capacity of 1.
fn non_zero_capacity(capacity: usize) -> usize {
    debug_assert!(capacity != 0, "unexpected 0 capacity");
    capacity.max(1)
}
80
81impl<I, const MAX_HEAP: bool> DistanceHeap<I, MAX_HEAP> {
82    pub fn with_capacity(capacity: usize) -> Self {
83        let capacity = non_zero_capacity(capacity);
84        Self(BinaryHeap::with_capacity(capacity))
85    }
86
87    pub fn push(&mut self, distance: VectorDistance, item: I) {
88        self.0.push(HeapNode { distance, item });
89    }
90
91    pub fn top(&self) -> Option<(VectorDistance, &I)> {
92        self.0.peek().map(|node| (node.distance, &node.item))
93    }
94
95    pub fn pop(&mut self) -> Option<(VectorDistance, I)> {
96        self.0.pop().map(|node| (node.distance, node.item))
97    }
98}
99
100pub struct BoundedNearest<I> {
101    heap: MaxDistanceHeap<I>,
102    capacity: usize,
103}
104
105impl<I> BoundedNearest<I> {
106    pub fn new(capacity: usize) -> Self {
107        Self {
108            heap: DistanceHeap(BinaryHeap::with_capacity(capacity)),
109            capacity,
110        }
111    }
112
113    pub fn furthest(&self) -> Option<(VectorDistance, &I)> {
114        self.heap.top()
115    }
116
117    pub fn insert(
118        &mut self,
119        distance: VectorDistance,
120        get_item: impl FnOnce() -> I,
121    ) -> Option<(VectorDistance, I)> {
122        if self.heap.0.len() >= self.capacity {
123            // we have restricted that `capacity` cannot be 0, so
124            // when `heap.len() >= capacity`, the heap must be non-empty.
125            let mut top = self.heap.0.peek_mut().expect("non-empty");
126            if top.distance > distance {
127                let prev_node = replace(
128                    &mut *top,
129                    HeapNode {
130                        distance,
131                        item: get_item(),
132                    },
133                );
134                Some((prev_node.distance, prev_node.item))
135            } else {
136                None
137            }
138        } else {
139            self.heap.0.push(HeapNode {
140                distance,
141                item: get_item(),
142            });
143            None
144        }
145    }
146
147    pub fn collect(self) -> Vec<I> {
148        self.collect_with(|_, item| item, None)
149    }
150
151    pub fn collect_with<O>(
152        mut self,
153        mut f: impl FnMut(VectorDistance, I) -> O,
154        limit: Option<usize>,
155    ) -> Vec<O> {
156        let size = self.heap.0.len();
157        let size = if let Some(limit) = limit {
158            min(size, limit)
159        } else {
160            size
161        };
162        let mut vec = Vec::with_capacity(size);
163        let uninit_slice = vec.spare_capacity_mut();
164        while self.heap.0.len() > size {
165            self.heap.pop();
166        }
167        assert_eq!(size, self.heap.0.len());
168        let mut i = size;
169        // elements are popped from max to min, so we write elements from back to front to ensure that the output is sorted ascendingly.
170        while let Some(node) = self.heap.0.pop() {
171            i -= 1;
172            // safety: `i` is initialized as the size of `self.heap`. It must have decremented for once, and can
173            // decrement for at most `size` time, so it must be that 0 <= i < size
174            unsafe {
175                uninit_slice
176                    .get_unchecked_mut(i)
177                    .write(f(node.distance, node.item));
178            }
179        }
180        assert_eq!(i, 0);
181        // safety: should have write `size` elements to the vector.
182        unsafe { vec.set_len(size) }
183        vec
184    }
185
186    pub fn resize(&mut self, new_capacity: usize) {
187        let new_capacity = non_zero_capacity(new_capacity);
188        self.capacity = new_capacity;
189        while self.heap.0.len() > new_capacity {
190            self.heap.pop();
191        }
192    }
193
194    #[expect(clippy::len_without_is_empty)]
195    pub fn len(&self) -> usize {
196        self.heap.0.len()
197    }
198}
199
200impl<'a, I> IntoIterator for &'a BoundedNearest<I> {
201    type Item = (VectorDistance, &'a I);
202
203    type IntoIter = impl Iterator<Item = Self::Item>;
204
205    fn into_iter(self) -> Self::IntoIter {
206        self.heap.0.iter().map(|node| (node.distance, &node.item))
207    }
208}
209
210impl<I> IntoIterator for BoundedNearest<I> {
211    type Item = (VectorDistance, I);
212
213    type IntoIter = impl Iterator<Item = Self::Item>;
214
215    fn into_iter(self) -> Self::IntoIter {
216        self.heap
217            .0
218            .into_iter()
219            .map(|node| (node.distance, node.item))
220    }
221}
222
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use itertools::Itertools;
    use rand::{Rng, rng};
    use risingwave_common::array::VectorDistanceType;

    use crate::vector::test_utils::{gen_info, top_n};
    use crate::vector::utils::BoundedNearest;

    /// Inserts `count` random distances into a `BoundedNearest` of capacity `n`, collects with
    /// `limit`, and compares the result against `test_utils::top_n` applied to the whole input
    /// with `min(n, limit)` (or `n` when there is no limit).
    fn test_inner(count: usize, n: usize, limit: Option<usize>) {
        let candidates = (0..count)
            .map(|idx| (rng().random::<VectorDistanceType>(), idx))
            .collect_vec();

        let mut nearest = BoundedNearest::new(n);
        for &(distance, idx) in &candidates {
            nearest.insert(distance, || idx);
        }
        let actual = nearest.collect_with(|_, idx| gen_info(idx), limit);

        let mut expected = candidates
            .iter()
            .map(|&(distance, idx)| (distance, gen_info(idx)))
            .collect_vec();
        let effective_n = limit.map_or(n, |limit| min(n, limit));
        top_n(&mut expected, effective_n);
        let expected = expected.into_iter().map(|(_, info)| info).collect_vec();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_not_full_top_n() {
        for limit in [None, Some(3), Some(5), Some(7)] {
            test_inner(5, 10, limit);
        }
    }

    #[test]
    fn test_exact_size_top_n() {
        for limit in [None, Some(8), Some(10), Some(12)] {
            test_inner(10, 10, limit);
        }
    }

    #[test]
    fn test_oversize_top_n() {
        for limit in [None, Some(8), Some(10), Some(15), Some(20), Some(25)] {
            test_inner(20, 10, limit);
        }
    }
}