risingwave_common/vector/
distance.rs1use std::simd::Simd;
16use std::simd::num::SimdFloat;
17
18use risingwave_pb::common::PbDistanceType;
19
20use crate::array::{VectorDistanceType as VectorDistance, VectorDistanceType};
21use crate::types::VectorRef;
22use crate::vector::{MeasureDistance, MeasureDistanceBuilder};
23
/// Calls `$callback!` with the table of every supported distance measurement.
///
/// The callback receives a braced list of `(VariantName, BuilderType),` entries, each with a
/// trailing comma. Any extra tokens written after the callback name are appended verbatim after
/// that list. This is the single place where the set of measurements is enumerated.
#[macro_export]
macro_rules! for_all_distance_measurement {
    ($callback:ident $($extra:tt)*) => {
        $callback! {
            {
                (L1, $crate::vector::distance::L1Distance),
                (L2Sqr, $crate::vector::distance::L2SqrDistance),
                (Cosine, $crate::vector::distance::CosineDistance),
                (InnerProduct, $crate::vector::distance::InnerProductDistance),
            }
            $($extra)*
        }
    };
}
38
/// Generates the `DistanceMeasurement` enum and its conversion from the protobuf `PbDistanceType`.
///
/// The enum has one variant per entry of the table supplied by `for_all_distance_measurement!`.
/// Invoke with no arguments. The empty arm forwards to `for_all_distance_measurement!`, which
/// calls back into the first arm with the variant table.
macro_rules! define_measure {
    ({
        $(($distance_name:ident, $_distance_type:ty),)+
    }) => {
        /// A supported vector distance measurement. Each variant mirrors the
        /// `PbDistanceType` variant of the same name.
        #[derive(Clone, Copy, Debug, PartialEq, Eq)]
        pub enum DistanceMeasurement {
            $($distance_name),+
        }

        impl From<PbDistanceType> for DistanceMeasurement {
            /// Maps each specified protobuf distance type to the same-named variant.
            ///
            /// # Panics
            ///
            /// Panics on `PbDistanceType::Unspecified`, which has no corresponding measurement.
            fn from(value: PbDistanceType) -> Self {
                match value {
                    PbDistanceType::Unspecified => {
                        // Name the offending value so a malformed plan or proto is easy to
                        // diagnose, instead of a bare "entered unreachable code".
                        unreachable!(
                            "PbDistanceType::Unspecified cannot be converted to a DistanceMeasurement"
                        )
                    }
                    $(
                        PbDistanceType::$distance_name => {
                            DistanceMeasurement::$distance_name
                        }
                    ),+
                }
            }
        }
    };
    () => {
        for_all_distance_measurement! {define_measure}
    };
}
67
68define_measure!();
69
/// Evaluates a body with a type alias bound to the builder type of a runtime measurement.
///
/// The runtime `DistanceMeasurement` value selects the variant. Usage:
/// `dispatch_distance_measurement!(measurement, Builder, { /* code using Builder */ })`.
/// It expands to a `match` with one arm per variant listed by `for_all_distance_measurement!`.
/// Inside each arm, the given alias names that variant's builder type.
///
/// NOTE(review): the entry arm passes this macro's bare name as the callback. The name
/// `dispatch_distance_measurement` presumably has to be resolvable at the call site (e.g.
/// brought in with `use`). Confirm this for callers outside this crate.
#[macro_export]
macro_rules! dispatch_distance_measurement {
    // Expansion arm: invoked back by `for_all_distance_measurement!` with the variant table.
    ({
        $(($variant:ident, $builder:ty),)+
    },
    $selector:expr, $alias:ident, $arm_body:expr) => {
        match $selector {
            $(
                $crate::vector::distance::DistanceMeasurement::$variant => {
                    type $alias = $builder;
                    $arm_body
                }
            ),+
        }
    };
    // Entry arm: fetch the variant table, then re-enter the expansion arm above.
    ($selector:expr, $alias:ident, $arm_body:expr) => {
        $crate::for_all_distance_measurement! {dispatch_distance_measurement, $selector, $alias, $arm_body}
    };
}
89
90pub struct L1Distance;
91
92pub struct L1DistanceMeasure<'a>(VectorRef<'a>);
93
94impl MeasureDistanceBuilder for L1Distance {
95 type Measure<'a> = L1DistanceMeasure<'a>;
96
97 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
98 L1DistanceMeasure(target)
99 }
100}
101
102#[cfg_attr(not(test), expect(dead_code))]
103fn l1_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
104 let first = first.as_slice();
105 let second = second.as_slice();
106 let len = first.len();
107 assert_eq!(len, second.len());
108 (0..len)
109 .map(|i| {
110 let diff = first[i].0 - second[i].0;
111 diff.abs() as VectorDistance
112 })
113 .sum()
114}
115
116pub fn l1_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
117 faiss::utils::fvec_l1(first.as_raw_slice(), second.as_raw_slice()) as VectorDistance
118}
119
120impl<'a> MeasureDistance for L1DistanceMeasure<'a> {
121 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
122 l1_faiss(self.0, other)
123 }
124}
125
126pub struct L2SqrDistance;
127
128pub struct L2SqrDistanceMeasure<'a>(VectorRef<'a>);
133
134impl MeasureDistanceBuilder for L2SqrDistance {
135 type Measure<'a> = L2SqrDistanceMeasure<'a>;
136
137 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
138 L2SqrDistanceMeasure(target)
139 }
140}
141
142#[cfg_attr(not(test), expect(dead_code))]
143fn l2sqr_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
144 let first = first.as_slice();
145 let second = second.as_slice();
146 let len = first.len();
147 assert_eq!(len, second.len());
148 (0..len)
149 .map(|i| ((first[i].0 - second[i].0) as VectorDistance).powi(2))
150 .sum()
151}
152
153pub fn l2sqr_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
154 faiss::utils::fvec_l2sqr(first.as_raw_slice(), second.as_raw_slice()) as VectorDistance
155}
156
157impl<'a> MeasureDistance for L2SqrDistanceMeasure<'a> {
158 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
159 l2sqr_faiss(self.0, other)
160 }
161}
162
163pub struct CosineDistance;
164pub struct CosineDistanceMeasure<'a> {
165 target: VectorRef<'a>,
166 l2_norm: f32,
167}
168
169impl MeasureDistanceBuilder for CosineDistance {
170 type Measure<'a> = CosineDistanceMeasure<'a>;
171
172 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
173 let l2_norm = target.l2_norm();
174 CosineDistanceMeasure { target, l2_norm }
175 }
176}
177
178pub fn cosine_distance(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistanceType {
179 cosine_distance_inner(first, first.l2_norm(), second).unwrap_or(VectorDistanceType::NAN)
180}
181
182fn cosine_distance_inner(
183 first: VectorRef<'_>,
184 first_l2_norm: f32,
185 second: VectorRef<'_>,
186) -> Option<VectorDistanceType> {
187 assert_eq!(first.dimension(), second.dimension());
188 let l2_norm_mul = second.l2_norm() * first_l2_norm;
189 if l2_norm_mul < f32::MIN_POSITIVE {
190 None
191 } else {
192 Some(1.0 - inner_product_faiss(first, second) / l2_norm_mul as VectorDistance)
193 }
194}
195
196impl<'a> MeasureDistance for CosineDistanceMeasure<'a> {
197 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
198 cosine_distance_inner(self.target, self.l2_norm, other).unwrap_or({
199 1.1
201 })
202 }
203}
204
205pub struct InnerProductDistance;
206pub struct InnerProductDistanceMeasure<'a>(VectorRef<'a>);
207
208impl MeasureDistanceBuilder for InnerProductDistance {
209 type Measure<'a> = InnerProductDistanceMeasure<'a>;
210
211 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
212 InnerProductDistanceMeasure(target)
213 }
214}
215
216#[cfg_attr(not(test), expect(dead_code))]
217fn inner_product_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
218 let first = first.as_slice();
219 let second = second.as_slice();
220 let len = first.len();
221 assert_eq!(len, second.len());
222 (0..len)
223 .map(|i| (first[i].0 * second[i].0) as VectorDistance)
224 .sum::<VectorDistance>()
225}
226
227#[cfg_attr(not(test), expect(dead_code))]
228fn inner_product_simd(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
229 let first = first.as_raw_slice();
230 let second = second.as_raw_slice();
231 let len = first.len();
232 assert_eq!(len, second.len());
233 let mut sum = 0.0;
234 let mut start = 0;
235 let mut end = start + 32;
236 while end <= len {
237 let this = Simd::<f32, 32>::from_slice(&first[start..end]);
238 let target = Simd::<f32, 32>::from_slice(&second[start..end]);
239 sum += (this * target).reduce_sum() as VectorDistance;
240 start += 32;
241 end += 32;
242 }
243 (start..len)
244 .map(|i| (first[i] * second[i]) as VectorDistance)
245 .sum::<VectorDistance>()
246 + sum
247}
248
249pub fn inner_product_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
250 faiss::utils::fvec_inner_product(first.as_raw_slice(), second.as_raw_slice()) as VectorDistance
251}
252
253impl<'a> MeasureDistance for InnerProductDistanceMeasure<'a> {
254 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
255 -inner_product_faiss(self.0, other)
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::VectorVal;
    use crate::test_utils::rand_array::gen_vector_for_test;

    const VECTOR_LEN: usize = 10;

    const VEC1: [f32; VECTOR_LEN] = [
        0.45742255, 0.04135585, 0.7236407, 0.82355756, 0.837814, 0.09387952, 0.8907283, 0.20203716,
        0.2039721, 0.7972273,
    ];

    const VEC2: [f32; VECTOR_LEN] = [
        0.9755903, 0.42836714, 0.45131344, 0.8602846, 0.61997443, 0.9501612, 0.65076965,
        0.22877127, 0.97690505, 0.44438475,
    ];

    const FLOAT_ABS_EPS: f32 = 2e-5;
    const FLOAT_REL_EPS: f32 = 1e-6;

    /// Asserts that two numbers, converted to `f32`, agree within an absolute tolerance.
    ///
    /// The tolerance is `FLOAT_ABS_EPS` or `FLOAT_REL_EPS` scaled by the larger magnitude,
    /// whichever is bigger. Call sites pass (actual, expected) in some tests and
    /// (expected, actual) in others. The failure message therefore uses neutral `left`/`right`
    /// labels, like `assert_eq!`, instead of claiming which side is expected.
    macro_rules! assert_eq_float {
        ($left:expr, $right:expr) => {{
            let left: f32 = $left as _;
            let right: f32 = $right as _;
            let diff = (left - right).abs();
            let tol = FLOAT_ABS_EPS.max(FLOAT_REL_EPS * left.abs().max(right.abs()));
            assert!(
                diff <= tol,
                "float values differ: left: {}, right: {}, |Δ|={} > tol={}",
                left,
                right,
                diff,
                tol
            );
        }};
    }

    #[test]
    fn test_distance() {
        let first_vec = [0.238474_f32, 0.578234];
        let second_vec = [0.9327183_f32, 0.387495];
        let [v1_1, v1_2] = first_vec;
        let [v2_1, v2_2] = second_vec;
        let first_vec = VectorVal {
            inner: first_vec.map(Into::into).to_vec().into_boxed_slice(),
        };
        let second_vec = VectorVal {
            inner: second_vec.map(Into::into).to_vec().into_boxed_slice(),
        };
        let first_vec = first_vec.to_ref();
        let second_vec = second_vec.to_ref();
        assert_eq_float!(
            L1Distance::distance(first_vec, second_vec),
            (v1_1 - v2_1).abs() + (v1_2 - v2_2).abs()
        );
        assert_eq_float!(
            L2SqrDistance::distance(first_vec, second_vec),
            (v1_1 - v2_1).powi(2) + (v1_2 - v2_2).powi(2)
        );
        assert_eq_float!(
            CosineDistance::distance(first_vec, second_vec),
            1.0 - (v1_1 * v2_1 + v1_2 * v2_2)
                / ((v1_1.powi(2) + v1_2.powi(2)).sqrt() * (v2_1.powi(2) + v2_2.powi(2)).sqrt())
        );
        assert_eq_float!(
            InnerProductDistance::distance(first_vec, second_vec),
            -(v1_1 * v2_1 + v1_2 * v2_2)
        );
        {
            let v1 = gen_vector_for_test(128);
            let v2 = gen_vector_for_test(128);
            let trivial = inner_product_trivial(v1.to_ref(), v2.to_ref());
            assert_eq_float!(inner_product_simd(v1.to_ref(), v2.to_ref()), trivial);
            assert_eq_float!(inner_product_faiss(v1.to_ref(), v2.to_ref()), trivial);
            assert_eq_float!(
                l2sqr_trivial(v1.to_ref(), v2.to_ref()),
                l2sqr_faiss(v1.to_ref(), v2.to_ref())
            );
            assert_eq_float!(
                l1_trivial(v1.to_ref(), v2.to_ref()),
                l1_faiss(v1.to_ref(), v2.to_ref())
            );
        }
    }

    #[test]
    fn test_expect_distance() {
        assert_eq_float!(
            3.6808228,
            L1Distance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
        assert_eq_float!(
            2.054677,
            L2SqrDistance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
        assert_eq_float!(
            0.22848952,
            CosineDistance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
        assert_eq_float!(
            -3.2870955,
            InnerProductDistance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
    }

    /// A zero vector makes the cosine undefined. The free function reports NaN, while the
    /// bound measure reports its fixed 1.1 sentinel.
    #[test]
    fn test_cosine_zero_norm() {
        let zero = VectorVal {
            inner: [0.0_f32; 4].map(Into::into).to_vec().into_boxed_slice(),
        };
        let other = VectorVal {
            inner: [1.0_f32, 2.0, 3.0, 4.0]
                .map(Into::into)
                .to_vec()
                .into_boxed_slice(),
        };
        let zero = zero.to_ref();
        let other = other.to_ref();
        assert!(cosine_distance(zero, other).is_nan());
        assert!(cosine_distance(other, zero).is_nan());
        assert_eq!(CosineDistance::new(zero).measure(other), 1.1);
        assert_eq!(CosineDistance::new(other).measure(zero), 1.1);
    }
}