risingwave_storage/vector/
utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::BinaryHeap;
17use std::mem::replace;
18
19use crate::vector::VectorDistance;
20
21pub(super) fn compare_distance(first: VectorDistance, second: VectorDistance) -> Ordering {
22    first
23        .partial_cmp(&second)
24        .unwrap_or_else(|| panic!("failed to compare distance {} and {}.", first, second))
25}
26
27fn compare_distance_on_heap<const MAX_HEAP: bool>(
28    first: VectorDistance,
29    second: VectorDistance,
30) -> Ordering {
31    let (first, second) = if MAX_HEAP {
32        (first, second)
33    } else {
34        (second, first)
35    };
36    compare_distance(first, second)
37}
38
39pub(super) struct HeapNode<I, const MAX_HEAP: bool> {
40    distance: VectorDistance,
41    pub(super) item: I,
42}
43
44impl<I, const MAX_HEAP: bool> PartialEq for HeapNode<I, MAX_HEAP> {
45    fn eq(&self, other: &Self) -> bool {
46        self.distance.eq(&other.distance)
47    }
48}
49
50impl<I, const MAX_HEAP: bool> Eq for HeapNode<I, MAX_HEAP> {}
51
52impl<I, const MAX_HEAP: bool> PartialOrd for HeapNode<I, MAX_HEAP> {
53    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
54        Some(self.cmp(other))
55    }
56}
57
58impl<I, const MAX_HEAP: bool> Ord for HeapNode<I, MAX_HEAP> {
59    fn cmp(&self, other: &Self) -> Ordering {
60        compare_distance_on_heap::<MAX_HEAP>(self.distance, other.distance)
61    }
62}
63
64pub struct DistanceHeap<I, const MAX_HEAP: bool>(BinaryHeap<HeapNode<I, MAX_HEAP>>);
65
66pub type MaxDistanceHeap<I> = DistanceHeap<I, true>;
67pub type MinDistanceHeap<I> = DistanceHeap<I, false>;
68
69impl<I, const MAX_HEAP: bool> DistanceHeap<I, MAX_HEAP> {
70    pub fn with_capacity(capacity: usize) -> Self {
71        Self(BinaryHeap::with_capacity(capacity))
72    }
73
74    pub fn push(&mut self, distance: VectorDistance, item: I) {
75        self.0.push(HeapNode { distance, item });
76    }
77
78    pub fn top(&self) -> Option<(VectorDistance, &I)> {
79        self.0.peek().map(|node| (node.distance, &node.item))
80    }
81
82    pub fn pop(&mut self) -> Option<(VectorDistance, I)> {
83        self.0.pop().map(|node| (node.distance, node.item))
84    }
85}
86
87pub struct BoundedNearest<I> {
88    heap: MaxDistanceHeap<I>,
89    capacity: usize,
90}
91
92impl<I> BoundedNearest<I> {
93    pub fn new(capacity: usize) -> Self {
94        Self {
95            heap: DistanceHeap(BinaryHeap::with_capacity(capacity)),
96            capacity,
97        }
98    }
99
100    pub fn furthest(&self) -> Option<(VectorDistance, &I)> {
101        self.heap.top()
102    }
103
104    pub fn insert(
105        &mut self,
106        distance: VectorDistance,
107        get_item: impl FnOnce() -> I,
108    ) -> Option<(VectorDistance, I)> {
109        if self.heap.0.len() >= self.capacity {
110            let mut top = self.heap.0.peek_mut().expect("non-empty");
111            if top.distance > distance {
112                let prev_node = replace(
113                    &mut *top,
114                    HeapNode {
115                        distance,
116                        item: get_item(),
117                    },
118                );
119                Some((prev_node.distance, prev_node.item))
120            } else {
121                None
122            }
123        } else {
124            self.heap.0.push(HeapNode {
125                distance,
126                item: get_item(),
127            });
128            None
129        }
130    }
131
132    pub fn collect(mut self) -> Vec<I> {
133        let size = self.heap.0.len();
134        let mut vec = Vec::with_capacity(size);
135        let uninit_slice = vec.spare_capacity_mut();
136        let mut i = size;
137        // elements are popped from max to min, so we write elements from back to front to ensure that the output is sorted ascendingly.
138        while let Some(node) = self.heap.0.pop() {
139            i -= 1;
140            // safety: `i` is initialized as the size of `self.heap`. It must have decremented for once, and can
141            // decrement for at most `size` time, so it must be that 0 <= i < size
142            unsafe {
143                uninit_slice.get_unchecked_mut(i).write(node.item);
144            }
145        }
146        assert_eq!(i, 0);
147        // safety: should have write `size` elements to the vector.
148        unsafe { vec.set_len(size) }
149        vec
150    }
151
152    pub fn resize(&mut self, new_capacity: usize) {
153        self.capacity = new_capacity;
154        while self.heap.0.len() > new_capacity {
155            self.heap.pop();
156        }
157    }
158}
159
160impl<'a, I> IntoIterator for &'a BoundedNearest<I> {
161    type Item = (VectorDistance, &'a I);
162
163    type IntoIter = impl Iterator<Item = Self::Item>;
164
165    fn into_iter(self) -> Self::IntoIter {
166        self.heap.0.iter().map(|node| (node.distance, &node.item))
167    }
168}