risingwave_storage/vector/
distance.rs1use std::simd::Simd;
16use std::simd::num::SimdFloat;
17
18use crate::vector::{
19 MeasureDistance, MeasureDistanceBuilder, VectorDistance, VectorItem, VectorRef,
20};
21
/// Invokes the callback macro `$callback` with the list of supported distance
/// measurements.
///
/// The callback receives, as its first token tree, a brace-delimited list of
/// `(VariantName, path::to::BuilderType),` entries. Any tokens written after
/// the callback name are forwarded verbatim after that list.
#[macro_export]
macro_rules! for_all_distance_measurement {
    ($callback:ident $($extra:tt)*) => {
        $callback! {
            {
                (L1, $crate::vector::distance::L1Distance),
                (L2, $crate::vector::distance::L2Distance),
                (Cosine, $crate::vector::distance::CosineDistance),
                (InnerProduct, $crate::vector::distance::InnerProductDistance),
            }
            $($extra)*
        }
    };
}
36
37macro_rules! define_measure {
38 ({
39 $(($distance_name:ident, $_distance_type:ty),)+
40 }) => {
41 pub enum DistanceMeasurement {
42 $($distance_name),+
43 }
44 };
45 () => {
46 for_all_distance_measurement! {define_measure}
47 };
48}
49
50define_measure!();
51
/// Evaluates `$action` with `$alias` declared as a type alias for the builder
/// type that corresponds to a `DistanceMeasurement` value.
///
/// Call it as `dispatch_measurement!(measurement, Alias, body)`. The expansion
/// is a `match` over every measurement, so it evaluates to whatever `body`
/// evaluates to. `DistanceMeasurement` is referenced without a path and must
/// therefore be in scope at the call site.
#[macro_export]
macro_rules! dispatch_measurement {
    // Arm fed by `for_all_distance_measurement!`: build one match arm per
    // `(variant, builder type)` entry.
    ({
        $(($variant:ident, $builder:ty),)+
    },
    $value:expr, $alias:ident, $action:expr) => {
        match $value {
            $(
                DistanceMeasurement::$variant => {
                    type $alias = $builder;
                    $action
                }
            ),+
        }
    };
    // Public entry point: obtain the measurement list, then re-enter above.
    ($value:expr, $alias:ident, $action:expr) => {
        $crate::for_all_distance_measurement! {dispatch_measurement, $value, $alias, $action}
    };
}
71
72pub struct L1Distance;
73
74pub struct L1DistanceMeasure<'a>(VectorRef<'a>);
75
76impl MeasureDistanceBuilder for L1Distance {
77 type Measure<'a> = L1DistanceMeasure<'a>;
78
79 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
80 L1DistanceMeasure(target)
81 }
82}
83
84#[cfg_attr(not(test), expect(dead_code))]
85fn l1_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
86 let len = first.0.len();
87 assert_eq!(len, second.0.len());
88 (0..len)
89 .map(|i| {
90 let diff = first.0[i] - second.0[i];
91 diff.abs()
92 })
93 .sum()
94}
95
96fn l1_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
97 faiss::utils::fvec_l1(first.0, second.0)
98}
99
100impl<'a> MeasureDistance for L1DistanceMeasure<'a> {
101 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
102 l1_faiss(self.0, other)
103 }
104}
105
106pub struct L2Distance;
107
108pub struct L2DistanceMeasure<'a>(VectorRef<'a>);
113
114impl MeasureDistanceBuilder for L2Distance {
115 type Measure<'a> = L2DistanceMeasure<'a>;
116
117 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
118 L2DistanceMeasure(target)
119 }
120}
121
122#[cfg_attr(not(test), expect(dead_code))]
123fn l2_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
124 let len = first.0.len();
125 assert_eq!(len, second.0.len());
126 (0..len).map(|i| (first.0[i] - second.0[i]).powi(2)).sum()
127}
128
129fn l2_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
130 faiss::utils::fvec_l2sqr(first.0, second.0)
131}
132
133impl<'a> MeasureDistance for L2DistanceMeasure<'a> {
134 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
135 l2_faiss(self.0, other)
136 }
137}
138
139pub struct CosineDistance;
140pub struct CosineDistanceMeasure<'a> {
141 target: VectorRef<'a>,
142 magnitude: VectorItem,
143}
144
145impl MeasureDistanceBuilder for CosineDistance {
146 type Measure<'a> = CosineDistanceMeasure<'a>;
147
148 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
149 let magnitude = target.magnitude();
150 CosineDistanceMeasure { target, magnitude }
151 }
152}
153
154impl<'a> MeasureDistance for CosineDistanceMeasure<'a> {
155 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
156 let len = self.target.0.len();
157 assert_eq!(len, other.0.len());
158 let magnitude_mul = other.magnitude() * self.magnitude;
159 if magnitude_mul < f32::MIN_POSITIVE {
160 return 1.1;
162 }
163 1.0 - inner_product_faiss(self.target, other) / magnitude_mul
164 }
165}
166
167pub struct InnerProductDistance;
168pub struct InnerProductDistanceMeasure<'a>(VectorRef<'a>);
169
170impl MeasureDistanceBuilder for InnerProductDistance {
171 type Measure<'a> = InnerProductDistanceMeasure<'a>;
172
173 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
174 InnerProductDistanceMeasure(target)
175 }
176}
177
178#[cfg_attr(not(test), expect(dead_code))]
179fn inner_product_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
180 let len = first.0.len();
181 assert_eq!(len, second.0.len());
182 (0..len)
183 .map(|i| first.0[i] * second.0[i])
184 .sum::<VectorItem>()
185}
186
187#[cfg_attr(not(test), expect(dead_code))]
188fn inner_product_simd(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
189 let len = first.0.len();
190 assert_eq!(len, second.0.len());
191 let mut sum = 0.0;
192 let mut start = 0;
193 let mut end = start + 32;
194 while end <= len {
195 let this = Simd::<VectorItem, 32>::from_slice(&first.0[start..end]);
196 let target = Simd::<VectorItem, 32>::from_slice(&second.0[start..end]);
197 sum += (this * target).reduce_sum();
198 start += 32;
199 end += 32;
200 }
201 (start..len)
202 .map(|i| first.0[i] * second.0[i])
203 .sum::<VectorDistance>()
204 + sum
205}
206
207fn inner_product_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
208 faiss::utils::fvec_inner_product(first.0, second.0)
209}
210
211impl<'a> MeasureDistance for InnerProductDistanceMeasure<'a> {
212 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
213 -inner_product_faiss(self.0, other)
214 }
215}
216
217#[cfg(test)]
218mod tests {
219
220 use expect_test::expect;
221
222 use super::*;
223 use crate::vector::test_utils::gen_vector;
224 use crate::vector::{MeasureDistanceBuilder, VectorInner};
225
/// Number of elements in each of the fixed test vectors.
const VECTOR_LEN: usize = 10;

/// First fixed input vector used by `test_expect_distance`.
const VEC1: [f32; VECTOR_LEN] = [
    0.45742255, 0.04135585, 0.7236407, 0.82355756, 0.837814,
    0.09387952, 0.8907283, 0.20203716, 0.2039721, 0.7972273,
];

/// Second fixed input vector used by `test_expect_distance`.
const VEC2: [f32; VECTOR_LEN] = [
    0.9755903, 0.42836714, 0.45131344, 0.8602846, 0.61997443,
    0.9501612, 0.65076965, 0.22877127, 0.97690505, 0.44438475,
];

/// Tolerance constant consumed by the `assert_eq_float!` macro.
const FLOAT_ALLOWED_BIAS: f32 = 1e-5;
239
/// Asserts that two `f32` values are approximately equal.
///
/// `$first` is the actual value and `$second` the expected one. The check
/// passes when `|actual - expected|` is at most
/// `FLOAT_ALLOWED_BIAS * max(1, |actual|, |expected|)`.
///
/// The comparison is made on the absolute difference: a signed comparison
/// would accept any actual value that is smaller than the expected one, no
/// matter how far off it is. The tolerance grows with the magnitude of the
/// operands so that sums accumulated over long vectors are not held to an
/// absolute bound tighter than `f32` rounding allows; for operands no larger
/// than 1 it is exactly `FLOAT_ALLOWED_BIAS`.
///
/// Each operand is evaluated exactly once, actual value first.
macro_rules! assert_eq_float {
    ($first:expr, $second:expr) => {{
        let actual: f32 = $first;
        let expected: f32 = $second;
        let tolerance = FLOAT_ALLOWED_BIAS * expected.abs().max(actual.abs()).max(1.0);
        assert!(
            (actual - expected).abs() <= tolerance,
            "Expected: {}, Actual: {}",
            expected,
            actual
        );
    }};
}
250
/// Checks every distance against a hand-written formula on a two-element
/// vector, then cross-checks the alternative implementations against each
/// other on generated 128-element vectors.
#[test]
fn test_distance() {
    let lhs_items: [f32; 2] = [0.238474, 0.578234];
    let rhs_items: [f32; 2] = [0.9327183, 0.387495];
    let [a1, a2] = lhs_items;
    let [b1, b2] = rhs_items;
    let lhs = VectorInner(&lhs_items[..]);
    let rhs = VectorInner(&rhs_items[..]);

    let expected_l1 = (a1 - b1).abs() + (a2 - b2).abs();
    assert_eq_float!(L1Distance::distance(lhs, rhs), expected_l1);

    // The L2 value is the squared distance: no square root is applied.
    let expected_l2 = (a1 - b1).powi(2) + (a2 - b2).powi(2);
    assert_eq_float!(L2Distance::distance(lhs, rhs), expected_l2);

    let expected_cosine = 1.0
        - (a1 * b1 + a2 * b2)
            / ((a1.powi(2) + a2.powi(2)).sqrt() * (b1.powi(2) + b2.powi(2)).sqrt());
    assert_eq_float!(CosineDistance::distance(lhs, rhs), expected_cosine);

    let expected_inner_product = -(a1 * b1 + a2 * b2);
    assert_eq_float!(
        InnerProductDistance::distance(lhs, rhs),
        expected_inner_product
    );

    {
        let left = gen_vector(128);
        let right = gen_vector(128);
        let reference = inner_product_trivial(left.to_ref(), right.to_ref());
        assert_eq_float!(inner_product_simd(left.to_ref(), right.to_ref()), reference);
        assert_eq_float!(inner_product_faiss(left.to_ref(), right.to_ref()), reference);
        assert_eq_float!(
            l2_trivial(left.to_ref(), right.to_ref()),
            l2_faiss(left.to_ref(), right.to_ref())
        );
        assert_eq_float!(
            l1_trivial(left.to_ref(), right.to_ref()),
            l1_faiss(left.to_ref(), right.to_ref())
        );
    }
}
292
/// Snapshot test: pins the debug-formatted distance between `VEC1` and `VEC2`
/// for each of the four measurements.
#[test]
fn test_expect_distance() {
    let l1 = L1Distance::distance(VectorInner(&VEC1), VectorInner(&VEC2));
    expect![[r#"
        3.6808228
    "#]]
    .assert_debug_eq(&l1);

    let l2 = L2Distance::distance(VectorInner(&VEC1), VectorInner(&VEC2));
    expect![[r#"
        2.054677
    "#]]
    .assert_debug_eq(&l2);

    let cosine = CosineDistance::distance(VectorInner(&VEC1), VectorInner(&VEC2));
    expect![[r#"
        0.22848952
    "#]]
    .assert_debug_eq(&cosine);

    let inner_product = InnerProductDistance::distance(VectorInner(&VEC1), VectorInner(&VEC2));
    expect![[r#"
        -3.2870955
    "#]]
    .assert_debug_eq(&inner_product);
}
324}