risingwave_storage/vector/
mod.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15pub mod distance;
16pub use distance::DistanceMeasurement;
17
18pub mod utils;
19
20use std::sync::Arc;
21
22use crate::vector::utils::BoundedNearest;
23
/// Scalar type of a single vector component.
pub type VectorItem = f32;

/// Newtype wrapper around a vector storage `T`.
///
/// The storage decides ownership: see [`Vector`], [`VectorRef`] and [`VectorMutRef`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VectorInner<T>(T);

/// Owned vector whose items live in a reference-counted slice.
pub type Vector = VectorInner<Arc<[VectorItem]>>;
/// Read-only view over a borrowed slice of items.
pub type VectorRef<'a> = VectorInner<&'a [VectorItem]>;
/// Mutable view over a borrowed slice of items.
pub type VectorMutRef<'a> = VectorInner<&'a mut [VectorItem]>;
31
32impl Vector {
33    pub fn new(inner: &[VectorItem]) -> Self {
34        Self(Arc::from(inner))
35    }
36
37    pub fn to_ref(&self) -> VectorRef<'_> {
38        VectorInner(self.0.as_ref())
39    }
40
41    pub fn clone_from_ref(r: VectorRef<'_>) -> Self {
42        Self(Vec::from(r.0).into())
43    }
44
45    pub fn get_mut(&mut self) -> Option<VectorMutRef<'_>> {
46        Arc::get_mut(&mut self.0).map(VectorInner)
47    }
48
49    /// # Safety
50    ///
51    /// safe under the same condition to [`Arc::get_mut_unchecked`]
52    pub unsafe fn get_mut_unchecked(&mut self) -> VectorMutRef<'_> {
53        // safety: under unsafe function
54        unsafe { VectorInner(Arc::get_mut_unchecked(&mut self.0)) }
55    }
56}
57
58impl<T: AsRef<[VectorItem]>> VectorInner<T> {
59    pub fn magnitude(&self) -> VectorItem {
60        self.0
61            .as_ref()
62            .iter()
63            .map(|item| item.powi(2))
64            .sum::<VectorItem>()
65            .sqrt()
66    }
67
68    pub fn normalized(&self) -> Vector {
69        let slice = self.0.as_ref();
70        let len = slice.len();
71        let mut uninit = Arc::new_uninit_slice(len);
72        // safety: just initialized, must be owned
73        let uninit_mut = unsafe { Arc::get_mut_unchecked(&mut uninit) };
74        let magnitude = self.magnitude();
75        for i in 0..len {
76            // safety: 0 <= i < len
77            unsafe {
78                uninit_mut
79                    .get_unchecked_mut(i)
80                    .write(slice.get_unchecked(i) / magnitude)
81            };
82        }
83        // safety: initialized with len, and have set all item
84        unsafe { VectorInner(uninit.assume_init()) }
85    }
86}
87
88pub type VectorDistance = f32;
89
90pub trait OnNearestItem<O> = for<'i> Fn(VectorRef<'i>, VectorDistance, &'i [u8]) -> O;
91
92pub trait MeasureDistance<'a> {
93    fn measure(&self, other: VectorRef<'_>) -> VectorDistance;
94}
95
96pub trait MeasureDistanceBuilder {
97    type Measure<'a>: MeasureDistance<'a>;
98    fn new(target: VectorRef<'_>) -> Self::Measure<'_>;
99
100    fn distance(target: VectorRef<'_>, other: VectorRef<'_>) -> VectorDistance
101    where
102        Self: Sized,
103    {
104        Self::new(target).measure(other)
105    }
106}
107
108pub struct NearestBuilder<'a, O, M: MeasureDistanceBuilder> {
109    measure: M::Measure<'a>,
110    nearest: BoundedNearest<O>,
111}
112
113impl<'a, O, M: MeasureDistanceBuilder> NearestBuilder<'a, O, M> {
114    pub fn new(target: VectorRef<'a>, n: usize) -> Self {
115        assert!(n > 0);
116        NearestBuilder {
117            measure: M::new(target),
118            nearest: BoundedNearest::new(n),
119        }
120    }
121
122    pub fn add<'b>(
123        &mut self,
124        vecs: impl IntoIterator<Item = (VectorRef<'b>, &'b [u8])> + 'b,
125        on_nearest_item: impl OnNearestItem<O>,
126    ) {
127        for (vec, info) in vecs {
128            let distance = self.measure.measure(vec);
129            self.nearest
130                .insert(distance, || on_nearest_item(vec, distance, info));
131        }
132    }
133
134    pub fn finish(self) -> Vec<O> {
135        self.nearest.collect()
136    }
137}
138
/// Helpers for generating deterministic test data.
#[cfg(any(test, feature = "test"))]
pub mod test_utils {
    use std::sync::LazyLock;

    use bytes::Bytes;
    use itertools::Itertools;
    use parking_lot::Mutex;
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};

    use crate::vector::{Vector, VectorItem};

    /// Generates a vector of `d` random items.
    ///
    /// All calls draw from one process-wide RNG seeded with a fixed value.
    pub fn gen_vector(d: usize) -> Vector {
        static RNG: LazyLock<Mutex<StdRng>> =
            LazyLock::new(|| Mutex::new(StdRng::seed_from_u64(233)));
        // Lock once for the whole vector instead of once per item.
        let mut rng = RNG.lock();
        let items = (0..d).map(|_| rng.random::<VectorItem>()).collect_vec();
        Vector::new(&items)
    }

    /// Encodes `i` as its little-endian bytes (the length follows the width of `usize`).
    pub fn gen_info(i: usize) -> Bytes {
        Bytes::copy_from_slice(&i.to_le_bytes())
    }
}
166
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use itertools::Itertools;

    use crate::vector::distance::L2Distance;
    use crate::vector::test_utils::{gen_info, gen_vector};
    use crate::vector::{MeasureDistanceBuilder, NearestBuilder, Vector, VectorDistance};

    /// Generates `count` random 10-dimensional vectors, each paired with its index as info.
    fn gen_random_input(count: usize) -> Vec<(Vector, Bytes)> {
        (0..count).map(|i| (gen_vector(10), gen_info(i))).collect()
    }

    #[test]
    fn test_empty_top_n() {
        let vec = gen_vector(10);
        let builder = NearestBuilder::<'_, (), L2Distance>::new(vec.to_ref(), 10);
        assert!(builder.finish().is_empty());
    }

    /// Reference implementation: sort by ascending distance and keep at most `n` entries.
    fn top_n<O>(input: &mut Vec<(VectorDistance, O)>, n: usize) {
        input.sort_by(|(first_distance, _), (second_distance, _)| {
            first_distance.total_cmp(second_distance)
        });
        // `truncate` is a no-op when `n >= input.len()`, so no clamping is needed.
        input.truncate(n);
    }

    /// Feeds `count` random vectors to a builder limited to `n` items and compares the
    /// result with the reference `top_n`.
    fn test_inner(count: usize, n: usize) {
        let input = gen_random_input(count);
        let vec = gen_vector(10);
        // Use `n` for the builder as well, so the builder and the expected output always
        // agree on the limit.
        let mut builder = NearestBuilder::<'_, _, L2Distance>::new(vec.to_ref(), n);
        builder.add(
            input.iter().map(|(v, b)| (v.to_ref(), b.as_ref())),
            |_, d, b| (d, Bytes::copy_from_slice(b)),
        );
        let output = builder.finish();
        let mut expected_output = input
            .into_iter()
            .map(|(v, b)| (L2Distance::distance(vec.to_ref(), v.to_ref()), b))
            .collect_vec();
        top_n(&mut expected_output, n);
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_not_full_top_n() {
        test_inner(5, 10);
    }

    #[test]
    fn test_exact_size_top_n() {
        test_inner(10, 10);
    }

    #[test]
    fn test_oversize_top_n() {
        test_inner(20, 10);
    }
}