risingwave_storage/vector/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod distance;
16pub mod hnsw;
17pub use distance::DistanceMeasurement;
18
19pub mod utils;
20
21use std::sync::Arc;
22
23use crate::vector::utils::BoundedNearest;
24
/// Scalar type of a single vector component; every vector type in this module is a
/// sequence of `VectorItem`s.
pub type VectorItem = f32;
/// Newtype wrapper around the storage `T` that holds a vector's items.
///
/// The wrapped field is private to this module; the storage flavour (owned, borrowed,
/// mutably borrowed) is selected through the `Vector`, `VectorRef` and `VectorMutRef`
/// aliases.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct VectorInner<T>(T);
28
/// Owned vector whose items live in a reference-counted slice.
pub type Vector = VectorInner<Arc<[VectorItem]>>;
/// Read-only view over vector items borrowed for `'a`.
pub type VectorRef<'a> = VectorInner<&'a [VectorItem]>;
/// Mutable view over vector items borrowed for `'a`.
pub type VectorMutRef<'a> = VectorInner<&'a mut [VectorItem]>;
32
33impl Vector {
34    pub fn new(inner: &[VectorItem]) -> Self {
35        Self(Arc::from(inner))
36    }
37
38    pub fn to_ref(&self) -> VectorRef<'_> {
39        VectorInner(self.0.as_ref())
40    }
41
42    pub fn clone_from_ref(r: VectorRef<'_>) -> Self {
43        Self(Vec::from(r.0).into())
44    }
45
46    pub fn get_mut(&mut self) -> Option<VectorMutRef<'_>> {
47        Arc::get_mut(&mut self.0).map(VectorInner)
48    }
49
50    /// # Safety
51    ///
52    /// safe under the same condition to [`Arc::get_mut_unchecked`]
53    pub unsafe fn get_mut_unchecked(&mut self) -> VectorMutRef<'_> {
54        // safety: under unsafe function
55        unsafe { VectorInner(Arc::get_mut_unchecked(&mut self.0)) }
56    }
57}
58
59impl<'a> VectorRef<'a> {
60    pub fn from_slice(slice: &'a [VectorItem]) -> Self {
61        VectorInner(slice)
62    }
63}
64
65#[cfg_attr(not(test), expect(dead_code))]
66fn l2_norm_trivial(vec: &VectorInner<impl AsRef<[VectorItem]>>) -> VectorItem {
67    vec.0
68        .as_ref()
69        .iter()
70        .map(|item| item.powi(2))
71        .sum::<VectorItem>()
72        .sqrt()
73}
74
75fn l2_norm_faiss(vec: &VectorInner<impl AsRef<[VectorItem]>>) -> VectorItem {
76    faiss::utils::fvec_norm_l2sqr(vec.0.as_ref()).sqrt()
77}
78
79impl<T: AsRef<[VectorItem]>> VectorInner<T> {
80    pub fn dimension(&self) -> usize {
81        self.0.as_ref().len()
82    }
83
84    pub fn as_slice(&self) -> &[VectorItem] {
85        self.0.as_ref()
86    }
87
88    pub fn magnitude(&self) -> VectorItem {
89        l2_norm_faiss(self)
90    }
91
92    pub fn normalized(&self) -> Vector {
93        let slice = self.0.as_ref();
94        let len = slice.len();
95        let mut uninit = Arc::new_uninit_slice(len);
96        // safety: just initialized, must be owned
97        let uninit_mut = unsafe { Arc::get_mut_unchecked(&mut uninit) };
98        let magnitude = self.magnitude();
99        for i in 0..len {
100            // safety: 0 <= i < len
101            unsafe {
102                uninit_mut
103                    .get_unchecked_mut(i)
104                    .write(slice.get_unchecked(i) / magnitude)
105            };
106        }
107        // safety: initialized with len, and have set all item
108        unsafe { VectorInner(uninit.assume_init()) }
109    }
110}
111
/// Scalar type of a distance value returned by [`MeasureDistance::measure`].
pub type VectorDistance = f32;
113
/// Callback that turns a candidate (its vector, its measured distance and its info
/// bytes) into an output value of type `O`.
///
/// `NearestBuilder::add` hands it to `BoundedNearest::insert` wrapped in a closure.
// NOTE(review): presumably only invoked for candidates that are actually retained —
// confirm against `BoundedNearest::insert`.
pub trait OnNearestItem<O> = for<'i> Fn(VectorRef<'i>, VectorDistance, &'i [u8]) -> O;
115
/// A distance measurement that has already been bound to a target (see
/// [`MeasureDistanceBuilder::new`]) and can be evaluated against other vectors.
pub trait MeasureDistance {
    /// Returns the distance between the bound target and `other`.
    fn measure(&self, other: VectorRef<'_>) -> VectorDistance;
}
119
120pub trait MeasureDistanceBuilder {
121    type Measure<'a>: MeasureDistance + 'a;
122    fn new(target: VectorRef<'_>) -> Self::Measure<'_>;
123
124    fn distance(target: VectorRef<'_>, other: VectorRef<'_>) -> VectorDistance
125    where
126        Self: Sized,
127    {
128        Self::new(target).measure(other)
129    }
130}
131
/// Accumulates candidates measured against a fixed target vector, delegating the
/// bounded bookkeeping to [`BoundedNearest`].
pub struct NearestBuilder<'a, O, M: MeasureDistanceBuilder> {
    /// Measurement bound to the target vector passed to `new`.
    measure: M::Measure<'a>,
    /// Bounded collection of outputs, created with the `n` passed to `new`.
    nearest: BoundedNearest<O>,
}
136
137impl<'a, O, M: MeasureDistanceBuilder> NearestBuilder<'a, O, M> {
138    pub fn new(target: VectorRef<'a>, n: usize) -> Self {
139        assert!(n > 0);
140        NearestBuilder {
141            measure: M::new(target),
142            nearest: BoundedNearest::new(n),
143        }
144    }
145
146    pub fn add<'b>(
147        &mut self,
148        vecs: impl IntoIterator<Item = (VectorRef<'b>, &'b [u8])> + 'b,
149        on_nearest_item: impl OnNearestItem<O>,
150    ) {
151        for (vec, info) in vecs {
152            let distance = self.measure.measure(vec);
153            self.nearest
154                .insert(distance, || on_nearest_item(vec, distance, info));
155        }
156    }
157
158    pub fn finish(self) -> Vec<O> {
159        self.nearest.collect()
160    }
161}
162
/// Helpers for tests that need random vectors, per-item info payloads and a reference
/// top-n implementation.
#[cfg(any(test, feature = "test"))]
pub mod test_utils {
    use std::sync::LazyLock;

    use bytes::Bytes;
    use itertools::Itertools;
    use parking_lot::Mutex;
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};

    use crate::store::Vector;
    use crate::vector::{VectorDistance, VectorItem};

    /// Generates a vector of `d` random items drawn from a process-wide RNG seeded
    /// with a fixed value.
    pub fn gen_vector(d: usize) -> Vector {
        static RNG: LazyLock<Mutex<StdRng>> =
            LazyLock::new(|| Mutex::new(StdRng::seed_from_u64(233)));
        // Take the lock once for the whole vector instead of once per item.
        let mut rng = RNG.lock();
        Vector::new(
            &(0..d)
                .map(|_| rng.random::<VectorItem>())
                .collect_vec(),
        )
    }

    /// Encodes `i` as its little-endian bytes, used as the info payload of item `i`.
    pub fn gen_info(i: usize) -> Bytes {
        Bytes::copy_from_slice(i.to_le_bytes().as_slice())
    }

    /// Sorts `input` by ascending distance (stable, using `f32::total_cmp`) and keeps
    /// at most the first `n` entries.
    pub fn top_n<O>(input: &mut Vec<(VectorDistance, O)>, n: usize) {
        input.sort_by(|(first_distance, _), (second_distance, _)| {
            first_distance.total_cmp(second_distance)
        });
        // `truncate` is a no-op when `n >= input.len()`, so no clamping is required.
        input.truncate(n);
    }
}
199
#[cfg(test)]
mod tests {

    use bytes::Bytes;
    use itertools::Itertools;

    use crate::vector::distance::L2Distance;
    use crate::vector::test_utils::{gen_info, gen_vector, top_n};
    use crate::vector::{
        MeasureDistanceBuilder, NearestBuilder, Vector, VectorInner, l2_norm_faiss, l2_norm_trivial,
    };

    /// Generates `count` random 10-dimensional vectors, each paired with the info
    /// payload of its index.
    fn gen_random_input(count: usize) -> Vec<(Vector, Bytes)> {
        (0..count).map(|i| (gen_vector(10), gen_info(i))).collect()
    }

    #[test]
    fn test_vector() {
        let vec = [0.238474, 0.578234];
        let [v1_1, v1_2] = vec;
        let vec = VectorInner(&vec[..]);

        // Both norm implementations must agree with the hand-computed value.
        assert_eq!(vec.magnitude(), (v1_1.powi(2) + v1_2.powi(2)).sqrt());
        assert_eq!(l2_norm_faiss(&vec), l2_norm_trivial(&vec));

        let mut normalized_vec = Vector::new(&[v1_1 / vec.magnitude(), v1_2 / vec.magnitude()]);
        assert_eq!(vec.normalized(), normalized_vec);
        // `get_mut` succeeds only while the vector is the sole owner of its storage.
        assert!(normalized_vec.get_mut().is_some());
        let mut normalized_vec_clone = normalized_vec.clone();
        assert!(normalized_vec.get_mut().is_none());
        assert!(normalized_vec_clone.get_mut().is_none());
        drop(normalized_vec);
        assert!(normalized_vec_clone.get_mut().is_some());
    }

    #[test]
    fn test_empty_top_n() {
        let vec = gen_vector(10);
        let builder = NearestBuilder::<'_, (), L2Distance>::new(vec.to_ref(), 10);
        assert!(builder.finish().is_empty());
    }

    /// Feeds `count` random vectors into a builder limited to `n` results and compares
    /// the outcome with the reference `top_n` over the same distances.
    fn test_inner(count: usize, n: usize) {
        let input = gen_random_input(count);
        let vec = gen_vector(10);
        // Use the requested limit `n` for the builder so that it always matches the
        // limit applied to the expected output below.
        let mut builder = NearestBuilder::<'_, _, L2Distance>::new(vec.to_ref(), n);
        builder.add(
            input.iter().map(|(v, b)| (v.to_ref(), b.as_ref())),
            |_, d, b| (d, Bytes::copy_from_slice(b)),
        );
        let output = builder.finish();
        let mut expected_output = input
            .into_iter()
            .map(|(v, b)| (L2Distance::distance(vec.to_ref(), v.to_ref()), b))
            .collect_vec();
        top_n(&mut expected_output, n);
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_not_full_top_n() {
        test_inner(5, 10);
    }

    #[test]
    fn test_exact_size_top_n() {
        test_inner(10, 10);
    }

    #[test]
    fn test_oversize_top_n() {
        test_inner(20, 10);
    }
}