risingwave_storage/vector/
mod.rs1pub mod distance;
16pub use distance::DistanceMeasurement;
17
18pub mod utils;
19
20use std::sync::Arc;
21
22use crate::vector::utils::BoundedNearest;
23
/// Scalar element type stored in every vector of this module.
pub type VectorItem = f32;

/// Thin newtype around a vector storage `T`.
///
/// The storage decides the ownership flavour; see the [`Vector`],
/// [`VectorRef`] and [`VectorMutRef`] aliases below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VectorInner<T>(T);

/// Owned vector whose elements live in a reference-counted slice.
pub type Vector = VectorInner<Arc<[VectorItem]>>;
/// Shared, read-only view over a slice of vector elements.
pub type VectorRef<'a> = VectorInner<&'a [VectorItem]>;
/// Exclusive, writable view over a slice of vector elements.
pub type VectorMutRef<'a> = VectorInner<&'a mut [VectorItem]>;
31
32impl Vector {
33 pub fn new(inner: &[VectorItem]) -> Self {
34 Self(Arc::from(inner))
35 }
36
37 pub fn to_ref(&self) -> VectorRef<'_> {
38 VectorInner(self.0.as_ref())
39 }
40
41 pub fn clone_from_ref(r: VectorRef<'_>) -> Self {
42 Self(Vec::from(r.0).into())
43 }
44
45 pub fn get_mut(&mut self) -> Option<VectorMutRef<'_>> {
46 Arc::get_mut(&mut self.0).map(VectorInner)
47 }
48
49 pub unsafe fn get_mut_unchecked(&mut self) -> VectorMutRef<'_> {
53 unsafe { VectorInner(Arc::get_mut_unchecked(&mut self.0)) }
55 }
56}
57
58impl<T: AsRef<[VectorItem]>> VectorInner<T> {
59 pub fn magnitude(&self) -> VectorItem {
60 self.0
61 .as_ref()
62 .iter()
63 .map(|item| item.powi(2))
64 .sum::<VectorItem>()
65 .sqrt()
66 }
67
68 pub fn normalized(&self) -> Vector {
69 let slice = self.0.as_ref();
70 let len = slice.len();
71 let mut uninit = Arc::new_uninit_slice(len);
72 let uninit_mut = unsafe { Arc::get_mut_unchecked(&mut uninit) };
74 let magnitude = self.magnitude();
75 for i in 0..len {
76 unsafe {
78 uninit_mut
79 .get_unchecked_mut(i)
80 .write(slice.get_unchecked(i) / magnitude)
81 };
82 }
83 unsafe { VectorInner(uninit.assume_init()) }
85 }
86}
87
88pub type VectorDistance = f32;
89
90pub trait OnNearestItem<O> = for<'i> Fn(VectorRef<'i>, VectorDistance, &'i [u8]) -> O;
91
92pub trait MeasureDistance<'a> {
93 fn measure(&self, other: VectorRef<'_>) -> VectorDistance;
94}
95
96pub trait MeasureDistanceBuilder {
97 type Measure<'a>: MeasureDistance<'a>;
98 fn new(target: VectorRef<'_>) -> Self::Measure<'_>;
99
100 fn distance(target: VectorRef<'_>, other: VectorRef<'_>) -> VectorDistance
101 where
102 Self: Sized,
103 {
104 Self::new(target).measure(other)
105 }
106}
107
108pub struct NearestBuilder<'a, O, M: MeasureDistanceBuilder> {
109 measure: M::Measure<'a>,
110 nearest: BoundedNearest<O>,
111}
112
113impl<'a, O, M: MeasureDistanceBuilder> NearestBuilder<'a, O, M> {
114 pub fn new(target: VectorRef<'a>, n: usize) -> Self {
115 assert!(n > 0);
116 NearestBuilder {
117 measure: M::new(target),
118 nearest: BoundedNearest::new(n),
119 }
120 }
121
122 pub fn add<'b>(
123 &mut self,
124 vecs: impl IntoIterator<Item = (VectorRef<'b>, &'b [u8])> + 'b,
125 on_nearest_item: impl OnNearestItem<O>,
126 ) {
127 for (vec, info) in vecs {
128 let distance = self.measure.measure(vec);
129 self.nearest
130 .insert(distance, || on_nearest_item(vec, distance, info));
131 }
132 }
133
134 pub fn finish(self) -> Vec<O> {
135 self.nearest.collect()
136 }
137}
138
/// Helpers for generating test input; compiled for tests or when the `test`
/// feature is enabled.
#[cfg(any(test, feature = "test"))]
pub mod test_utils {
    use std::sync::LazyLock;

    use bytes::Bytes;
    use itertools::Itertools;
    use parking_lot::Mutex;
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};

    use crate::store::Vector;
    use crate::vector::VectorItem;

    /// Generates a vector with `d` random elements.
    ///
    /// All calls draw from one process-wide generator seeded with a fixed
    /// value; the generator is locked once per element.
    pub fn gen_vector(d: usize) -> Vector {
        static RNG: LazyLock<Mutex<StdRng>> =
            LazyLock::new(|| Mutex::new(StdRng::seed_from_u64(233)));
        let items = (0..d)
            .map(|_| RNG.lock().random::<VectorItem>())
            .collect_vec();
        Vector::new(&items)
    }

    /// Encodes `i` as its little-endian bytes.
    pub fn gen_info(i: usize) -> Bytes {
        let encoded = i.to_le_bytes();
        Bytes::copy_from_slice(encoded.as_slice())
    }
}
166
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use itertools::Itertools;

    use crate::vector::distance::L2Distance;
    use crate::vector::test_utils::{gen_info, gen_vector};
    use crate::vector::{MeasureDistanceBuilder, NearestBuilder, Vector, VectorDistance};

    /// Generates `count` random 10-dimensional vectors, each paired with the
    /// encoded index it was generated at.
    fn gen_random_input(count: usize) -> Vec<(Vector, Bytes)> {
        (0..count).map(|i| (gen_vector(10), gen_info(i))).collect()
    }

    #[test]
    fn test_empty_top_n() {
        let vec = gen_vector(10);
        let builder = NearestBuilder::<'_, (), L2Distance>::new(vec.to_ref(), 10);
        assert!(builder.finish().is_empty());
    }

    /// Reference implementation: sorts by ascending distance and keeps at most
    /// the first `n` entries.
    fn top_n<O>(input: &mut Vec<(VectorDistance, O)>, n: usize) {
        input.sort_by(|(first_distance, _), (second_distance, _)| {
            first_distance.total_cmp(second_distance)
        });
        // `truncate` is a no-op when `n` exceeds the current length.
        input.truncate(n);
    }

    /// Feeds `count` random vectors into a builder bounded by `n` and compares
    /// the result with the reference `top_n`.
    fn test_inner(count: usize, n: usize) {
        let input = gen_random_input(count);
        let vec = gen_vector(10);
        // The builder must use the same bound as the reference implementation;
        // it was previously hard-coded to 10, which silently ignored `n`.
        let mut builder = NearestBuilder::<'_, _, L2Distance>::new(vec.to_ref(), n);
        builder.add(
            input.iter().map(|(v, b)| (v.to_ref(), b.as_ref())),
            |_, d, b| (d, Bytes::copy_from_slice(b)),
        );
        let output = builder.finish();
        let mut expected_output = input
            .into_iter()
            .map(|(v, b)| (L2Distance::distance(vec.to_ref(), v.to_ref()), b))
            .collect_vec();
        top_n(&mut expected_output, n);
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_not_full_top_n() {
        test_inner(5, 10);
    }

    #[test]
    fn test_exact_size_top_n() {
        test_inner(10, 10);
    }

    #[test]
    fn test_oversize_top_n() {
        test_inner(20, 10);
    }

    /// Uses a bound other than 10 so that `n` is exercised on both the builder
    /// and the reference side.
    #[test]
    fn test_small_top_n() {
        test_inner(20, 5);
    }
}