risingwave_storage/vector/
mod.rs1pub mod distance;
16pub mod hnsw;
17pub use distance::DistanceMeasurement;
18
19pub mod utils;
20
21use std::sync::Arc;
22
23use crate::vector::utils::BoundedNearest;
24
/// Scalar type of a single vector component.
pub type VectorItem = f32;

/// Thin wrapper around a vector's component storage `T`.
///
/// The concrete storage flavours used by the crate are the aliases below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VectorInner<T>(T);

/// Owned vector whose components live in a reference-counted `Arc<[VectorItem]>` buffer.
pub type Vector = VectorInner<Arc<[VectorItem]>>;
/// Borrowed, read-only view of a vector's components.
pub type VectorRef<'a> = VectorInner<&'a [VectorItem]>;
/// Borrowed, mutable view of a vector's components.
pub type VectorMutRef<'a> = VectorInner<&'a mut [VectorItem]>;
32
33impl Vector {
34 pub fn new(inner: &[VectorItem]) -> Self {
35 Self(Arc::from(inner))
36 }
37
38 pub fn to_ref(&self) -> VectorRef<'_> {
39 VectorInner(self.0.as_ref())
40 }
41
42 pub fn clone_from_ref(r: VectorRef<'_>) -> Self {
43 Self(Vec::from(r.0).into())
44 }
45
46 pub fn get_mut(&mut self) -> Option<VectorMutRef<'_>> {
47 Arc::get_mut(&mut self.0).map(VectorInner)
48 }
49
50 pub unsafe fn get_mut_unchecked(&mut self) -> VectorMutRef<'_> {
54 unsafe { VectorInner(Arc::get_mut_unchecked(&mut self.0)) }
56 }
57}
58
59impl<'a> VectorRef<'a> {
60 pub fn from_slice(slice: &'a [VectorItem]) -> Self {
61 VectorInner(slice)
62 }
63}
64
65#[cfg_attr(not(test), expect(dead_code))]
66fn l2_norm_trivial(vec: &VectorInner<impl AsRef<[VectorItem]>>) -> VectorItem {
67 vec.0
68 .as_ref()
69 .iter()
70 .map(|item| item.powi(2))
71 .sum::<VectorItem>()
72 .sqrt()
73}
74
75fn l2_norm_faiss(vec: &VectorInner<impl AsRef<[VectorItem]>>) -> VectorItem {
76 faiss::utils::fvec_norm_l2sqr(vec.0.as_ref()).sqrt()
77}
78
79impl<T: AsRef<[VectorItem]>> VectorInner<T> {
80 pub fn dimension(&self) -> usize {
81 self.0.as_ref().len()
82 }
83
84 pub fn as_slice(&self) -> &[VectorItem] {
85 self.0.as_ref()
86 }
87
88 pub fn magnitude(&self) -> VectorItem {
89 l2_norm_faiss(self)
90 }
91
92 pub fn normalized(&self) -> Vector {
93 let slice = self.0.as_ref();
94 let len = slice.len();
95 let mut uninit = Arc::new_uninit_slice(len);
96 let uninit_mut = unsafe { Arc::get_mut_unchecked(&mut uninit) };
98 let magnitude = self.magnitude();
99 for i in 0..len {
100 unsafe {
102 uninit_mut
103 .get_unchecked_mut(i)
104 .write(slice.get_unchecked(i) / magnitude)
105 };
106 }
107 unsafe { VectorInner(uninit.assume_init()) }
109 }
110}
111
112pub type VectorDistance = f32;
113
114pub trait OnNearestItem<O> = for<'i> Fn(VectorRef<'i>, VectorDistance, &'i [u8]) -> O;
115
116pub trait MeasureDistance {
117 fn measure(&self, other: VectorRef<'_>) -> VectorDistance;
118}
119
120pub trait MeasureDistanceBuilder {
121 type Measure<'a>: MeasureDistance + 'a;
122 fn new(target: VectorRef<'_>) -> Self::Measure<'_>;
123
124 fn distance(target: VectorRef<'_>, other: VectorRef<'_>) -> VectorDistance
125 where
126 Self: Sized,
127 {
128 Self::new(target).measure(other)
129 }
130}
131
132pub struct NearestBuilder<'a, O, M: MeasureDistanceBuilder> {
133 measure: M::Measure<'a>,
134 nearest: BoundedNearest<O>,
135}
136
137impl<'a, O, M: MeasureDistanceBuilder> NearestBuilder<'a, O, M> {
138 pub fn new(target: VectorRef<'a>, n: usize) -> Self {
139 assert!(n > 0);
140 NearestBuilder {
141 measure: M::new(target),
142 nearest: BoundedNearest::new(n),
143 }
144 }
145
146 pub fn add<'b>(
147 &mut self,
148 vecs: impl IntoIterator<Item = (VectorRef<'b>, &'b [u8])> + 'b,
149 on_nearest_item: impl OnNearestItem<O>,
150 ) {
151 for (vec, info) in vecs {
152 let distance = self.measure.measure(vec);
153 self.nearest
154 .insert(distance, || on_nearest_item(vec, distance, info));
155 }
156 }
157
158 pub fn finish(self) -> Vec<O> {
159 self.nearest.collect()
160 }
161}
162
#[cfg(any(test, feature = "test"))]
pub mod test_utils {
    use std::sync::LazyLock;

    use bytes::Bytes;
    use itertools::Itertools;
    use parking_lot::Mutex;
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};

    use crate::store::Vector;
    use crate::vector::{VectorDistance, VectorItem};

    /// Generates a `d`-dimensional vector with random components.
    ///
    /// All calls share one process-wide RNG seeded with `233`, so a single-threaded
    /// sequence of calls is reproducible. The RNG lock is taken once per vector, so the
    /// components of one vector are consecutive draws even when other threads are also
    /// generating vectors.
    pub fn gen_vector(d: usize) -> Vector {
        static RNG: LazyLock<Mutex<StdRng>> =
            LazyLock::new(|| Mutex::new(StdRng::seed_from_u64(233)));
        let items = {
            let mut rng = RNG.lock();
            (0..d).map(|_| rng.random::<VectorItem>()).collect_vec()
        };
        Vector::new(&items)
    }

    /// Encodes `i` as its little-endian bytes, used as per-item info payload in tests.
    pub fn gen_info(i: usize) -> Bytes {
        Bytes::copy_from_slice(i.to_le_bytes().as_slice())
    }

    /// Brute-force reference for nearest-`n`: stably sorts `input` by ascending distance
    /// (using `f32::total_cmp`) and keeps only the first `n` entries.
    pub fn top_n<O>(input: &mut Vec<(VectorDistance, O)>, n: usize) {
        input.sort_by(|(first_distance, _), (second_distance, _)| {
            first_distance.total_cmp(second_distance)
        });
        // `truncate` is a no-op when `n >= input.len()`, so no clamping is needed.
        input.truncate(n);
    }
}
199
200#[cfg(test)]
201mod tests {
202
203 use bytes::Bytes;
204 use itertools::Itertools;
205
206 use crate::vector::distance::L2Distance;
207 use crate::vector::test_utils::{gen_info, gen_vector, top_n};
208 use crate::vector::{
209 MeasureDistanceBuilder, NearestBuilder, Vector, VectorInner, l2_norm_faiss, l2_norm_trivial,
210 };
211
212 fn gen_random_input(count: usize) -> Vec<(Vector, Bytes)> {
213 (0..count).map(|i| (gen_vector(10), gen_info(i))).collect()
214 }
215
216 #[test]
217 fn test_vector() {
218 let vec = [0.238474, 0.578234];
219 let [v1_1, v1_2] = vec;
220 let vec = VectorInner(&vec[..]);
221
222 assert_eq!(vec.magnitude(), (v1_1.powi(2) + v1_2.powi(2)).sqrt());
223 assert_eq!(l2_norm_faiss(&vec), l2_norm_trivial(&vec));
224
225 let mut normalized_vec = Vector::new(&[v1_1 / vec.magnitude(), v1_2 / vec.magnitude()]);
226 assert_eq!(vec.normalized(), normalized_vec);
227 assert!(normalized_vec.get_mut().is_some());
228 let mut normalized_vec_clone = normalized_vec.clone();
229 assert!(normalized_vec.get_mut().is_none());
230 assert!(normalized_vec_clone.get_mut().is_none());
231 drop(normalized_vec);
232 assert!(normalized_vec_clone.get_mut().is_some());
233 }
234
235 #[test]
236 fn test_empty_top_n() {
237 let vec = gen_vector(10);
238 let builder = NearestBuilder::<'_, (), L2Distance>::new(vec.to_ref(), 10);
239 assert!(builder.finish().is_empty());
240 }
241
242 fn test_inner(count: usize, n: usize) {
243 let input = gen_random_input(count);
244 let vec = gen_vector(10);
245 let mut builder = NearestBuilder::<'_, _, L2Distance>::new(vec.to_ref(), 10);
246 builder.add(
247 input.iter().map(|(v, b)| (v.to_ref(), b.as_ref())),
248 |_, d, b| (d, Bytes::copy_from_slice(b)),
249 );
250 let output = builder.finish();
251 let mut expected_output = input
252 .into_iter()
253 .map(|(v, b)| (L2Distance::distance(vec.to_ref(), v.to_ref()), b))
254 .collect_vec();
255 top_n(&mut expected_output, n);
256 assert_eq!(output, expected_output);
257 }
258
259 #[test]
260 fn test_not_full_top_n() {
261 test_inner(5, 10);
262 }
263
264 #[test]
265 fn test_exact_size_top_n() {
266 test_inner(10, 10);
267 }
268
269 #[test]
270 fn test_oversize_top_n() {
271 test_inner(20, 10);
272 }
273}