risingwave_common/vector/
distance.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::simd::Simd;
16use std::simd::num::SimdFloat;
17
18use risingwave_pb::common::PbDistanceType;
19
20use crate::array::{VectorDistanceType as VectorDistance, VectorDistanceType};
21use crate::types::VectorRef;
22use crate::vector::{MeasureDistance, MeasureDistanceBuilder};
23
/// Invokes `$macro` with the table of every supported distance measurement.
///
/// The table is passed as the first argument of `$macro`, shaped as
/// `{ (VariantName, path::to::BuilderType), ... }`. Any extra tokens written after `$macro`
/// are forwarded verbatim after the table. Each `VariantName` is also used as a
/// `PbDistanceType` variant name by `define_measure!`, so the two lists must stay in sync.
#[macro_export]
macro_rules! for_all_distance_measurement {
    ($macro:ident $($param:tt)*) => {
        $macro! {
            {
                (L1, $crate::vector::distance::L1Distance),
                (L2Sqr, $crate::vector::distance::L2SqrDistance),
                (Cosine, $crate::vector::distance::CosineDistance),
                (InnerProduct, $crate::vector::distance::InnerProductDistance),
            }
            // Caller-supplied trailing tokens, passed through unchanged.
            $($param)*
        }
    };
}
38
39macro_rules! define_measure {
40    ({
41        $(($distance_name:ident, $_distance_type:ty),)+
42    }) => {
43        #[derive(Clone, Copy)]
44        pub enum DistanceMeasurement {
45            $($distance_name),+
46        }
47
48        impl From<PbDistanceType> for DistanceMeasurement {
49            fn from(value: PbDistanceType) -> Self {
50                match value {
51                    PbDistanceType::Unspecified => {
52                        unreachable!()
53                    }
54                    $(
55                        PbDistanceType::$distance_name => {
56                            DistanceMeasurement::$distance_name
57                        }
58                    ),+
59                }
60            }
61        }
62    };
63    () => {
64        for_all_distance_measurement! {define_measure}
65    };
66}
67
68define_measure!();
69
/// Dispatches on a `DistanceMeasurement` value at runtime.
///
/// Expands to a `match` over the measurement. In each arm, `$type_name` is declared as a local
/// type alias for that measurement's builder type (e.g. `L1Distance` for
/// `DistanceMeasurement::L1`), and then `$body` is evaluated. The value of `$body` is the value
/// of the whole invocation.
///
/// Example: `dispatch_distance_measurement!(m, Builder, Builder::distance(a, b))`.
#[macro_export]
macro_rules! dispatch_distance_measurement {
    // Internal arm: reached via `for_all_distance_measurement!`, which prepends the measurement
    // table. Emits one `match` arm per table entry.
    ({
        $(($distance_name:ident, $distance_type:ty),)+
    },
    $measurement:expr, $type_name:ident, $body:expr) => {
        match $measurement {
            $(
                $crate::vector::distance::DistanceMeasurement::$distance_name => {
                    type $type_name = $distance_type;
                    $body
                }
            ),+
        }
    };
    // Entry arm: fetches the table and re-invokes this macro, which then hits the arm above.
    // NOTE(review): the callback is passed as the bare ident `dispatch_distance_measurement`,
    // so the re-invocation is resolved by name at the expansion site. Callers in other crates
    // likely need this macro in scope (e.g. via `use`) — confirm.
    ($measurement:expr, $type_name:ident, $body:expr) => {
        $crate::for_all_distance_measurement! {dispatch_distance_measurement, $measurement, $type_name, $body}
    };
}
89
90pub struct L1Distance;
91
92pub struct L1DistanceMeasure<'a>(VectorRef<'a>);
93
94impl MeasureDistanceBuilder for L1Distance {
95    type Measure<'a> = L1DistanceMeasure<'a>;
96
97    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
98        L1DistanceMeasure(target)
99    }
100}
101
102#[cfg_attr(not(test), expect(dead_code))]
103fn l1_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
104    let first = first.as_slice();
105    let second = second.as_slice();
106    let len = first.len();
107    assert_eq!(len, second.len());
108    (0..len)
109        .map(|i| {
110            let diff = first[i].0 - second[i].0;
111            diff.abs() as VectorDistance
112        })
113        .sum()
114}
115
116pub fn l1_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
117    faiss::utils::fvec_l1(first.as_raw_slice(), second.as_raw_slice()) as VectorDistance
118}
119
120impl<'a> MeasureDistance for L1DistanceMeasure<'a> {
121    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
122        l1_faiss(self.0, other)
123    }
124}
125
126pub struct L2SqrDistance;
127
128/// Measure the l2 distance
129///
130/// In this implementation, we don't take the square root to avoid unnecessary computation, because
131/// we only want comparison rather than the actual distance.
132pub struct L2SqrDistanceMeasure<'a>(VectorRef<'a>);
133
134impl MeasureDistanceBuilder for L2SqrDistance {
135    type Measure<'a> = L2SqrDistanceMeasure<'a>;
136
137    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
138        L2SqrDistanceMeasure(target)
139    }
140}
141
142#[cfg_attr(not(test), expect(dead_code))]
143fn l2sqr_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
144    let first = first.as_slice();
145    let second = second.as_slice();
146    let len = first.len();
147    assert_eq!(len, second.len());
148    (0..len)
149        .map(|i| ((first[i].0 - second[i].0) as VectorDistance).powi(2))
150        .sum()
151}
152
153pub fn l2sqr_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
154    faiss::utils::fvec_l2sqr(first.as_raw_slice(), second.as_raw_slice()) as VectorDistance
155}
156
157impl<'a> MeasureDistance for L2SqrDistanceMeasure<'a> {
158    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
159        l2sqr_faiss(self.0, other)
160    }
161}
162
163pub struct CosineDistance;
164pub struct CosineDistanceMeasure<'a> {
165    target: VectorRef<'a>,
166    l2_norm: f32,
167}
168
169impl MeasureDistanceBuilder for CosineDistance {
170    type Measure<'a> = CosineDistanceMeasure<'a>;
171
172    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
173        let l2_norm = target.l2_norm();
174        CosineDistanceMeasure { target, l2_norm }
175    }
176}
177
178pub fn cosine_distance(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistanceType {
179    cosine_distance_inner(first, first.l2_norm(), second).unwrap_or(VectorDistanceType::NAN)
180}
181
182fn cosine_distance_inner(
183    first: VectorRef<'_>,
184    first_l2_norm: f32,
185    second: VectorRef<'_>,
186) -> Option<VectorDistanceType> {
187    assert_eq!(first.dimension(), second.dimension());
188    let l2_norm_mul = second.l2_norm() * first_l2_norm;
189    if l2_norm_mul < f32::MIN_POSITIVE {
190        None
191    } else {
192        Some(1.0 - inner_product_faiss(first, second) / l2_norm_mul as VectorDistance)
193    }
194}
195
196impl<'a> MeasureDistance for CosineDistanceMeasure<'a> {
197    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
198        cosine_distance_inner(self.target, self.l2_norm, other).unwrap_or({
199            // If either vector is zero, the distance is the further 1.1
200            1.1
201        })
202    }
203}
204
205pub struct InnerProductDistance;
206pub struct InnerProductDistanceMeasure<'a>(VectorRef<'a>);
207
208impl MeasureDistanceBuilder for InnerProductDistance {
209    type Measure<'a> = InnerProductDistanceMeasure<'a>;
210
211    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
212        InnerProductDistanceMeasure(target)
213    }
214}
215
216#[cfg_attr(not(test), expect(dead_code))]
217fn inner_product_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
218    let first = first.as_slice();
219    let second = second.as_slice();
220    let len = first.len();
221    assert_eq!(len, second.len());
222    (0..len)
223        .map(|i| (first[i].0 * second[i].0) as VectorDistance)
224        .sum::<VectorDistance>()
225}
226
227#[cfg_attr(not(test), expect(dead_code))]
228fn inner_product_simd(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
229    let first = first.as_raw_slice();
230    let second = second.as_raw_slice();
231    let len = first.len();
232    assert_eq!(len, second.len());
233    let mut sum = 0.0;
234    let mut start = 0;
235    let mut end = start + 32;
236    while end <= len {
237        let this = Simd::<f32, 32>::from_slice(&first[start..end]);
238        let target = Simd::<f32, 32>::from_slice(&second[start..end]);
239        sum += (this * target).reduce_sum() as VectorDistance;
240        start += 32;
241        end += 32;
242    }
243    (start..len)
244        .map(|i| (first[i] * second[i]) as VectorDistance)
245        .sum::<VectorDistance>()
246        + sum
247}
248
249pub fn inner_product_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
250    faiss::utils::fvec_inner_product(first.as_raw_slice(), second.as_raw_slice()) as VectorDistance
251}
252
253impl<'a> MeasureDistance for InnerProductDistanceMeasure<'a> {
254    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
255        -inner_product_faiss(self.0, other)
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::VectorVal;
    use crate::test_utils::rand_array::gen_vector_for_test;

    const VECTOR_LEN: usize = 10;

    const VEC1: [f32; VECTOR_LEN] = [
        0.45742255, 0.04135585, 0.7236407, 0.82355756, 0.837814, 0.09387952, 0.8907283, 0.20203716,
        0.2039721, 0.7972273,
    ];

    const VEC2: [f32; VECTOR_LEN] = [
        0.9755903, 0.42836714, 0.45131344, 0.8602846, 0.61997443, 0.9501612, 0.65076965,
        0.22877127, 0.97690505, 0.44438475,
    ];

    const FLOAT_ABS_EPS: f32 = 2e-5;
    const FLOAT_REL_EPS: f32 = 1e-6;

    // Asserts two numbers (both compared as `f32`) are equal within `FLOAT_ABS_EPS` absolute or
    // `FLOAT_REL_EPS` relative to the larger magnitude, whichever is looser. The arguments are
    // symmetric, and call sites pass the expected value on either side, so the failure message
    // labels them left/right rather than expected/actual.
    macro_rules! assert_eq_float {
        ($first:expr, $second:expr) => {{
            let a: f32 = $first as _;
            let b: f32 = $second as _;
            let diff = (a - b).abs();
            let tol = FLOAT_ABS_EPS.max(FLOAT_REL_EPS * a.abs().max(b.abs()));
            assert!(
                diff <= tol,
                "left: {}, right: {}, |Δ|={} > tol={}",
                a,
                b,
                diff,
                tol
            );
        }};
    }

    #[test]
    fn test_distance() {
        let first_vec = [0.238474_f32, 0.578234];
        let second_vec = [0.9327183_f32, 0.387495];
        let [v1_1, v1_2] = first_vec;
        let [v2_1, v2_2] = second_vec;
        let first_vec = VectorVal {
            inner: first_vec.map(Into::into).to_vec().into_boxed_slice(),
        };
        let second_vec = VectorVal {
            inner: second_vec.map(Into::into).to_vec().into_boxed_slice(),
        };
        let first_vec = first_vec.to_ref();
        let second_vec = second_vec.to_ref();
        assert_eq_float!(
            L1Distance::distance(first_vec, second_vec),
            (v1_1 - v2_1).abs() + (v1_2 - v2_2).abs()
        );
        assert_eq_float!(
            L2SqrDistance::distance(first_vec, second_vec),
            (v1_1 - v2_1).powi(2) + (v1_2 - v2_2).powi(2)
        );
        assert_eq_float!(
            CosineDistance::distance(first_vec, second_vec),
            1.0 - (v1_1 * v2_1 + v1_2 * v2_2)
                / ((v1_1.powi(2) + v1_2.powi(2)).sqrt() * (v2_1.powi(2) + v2_2.powi(2)).sqrt())
        );
        assert_eq_float!(
            InnerProductDistance::distance(first_vec, second_vec),
            -(v1_1 * v2_1 + v1_2 * v2_2)
        );
        {
            let v1 = gen_vector_for_test(128);
            let v2 = gen_vector_for_test(128);
            let trivial = inner_product_trivial(v1.to_ref(), v2.to_ref());
            assert_eq_float!(inner_product_simd(v1.to_ref(), v2.to_ref()), trivial);
            assert_eq_float!(inner_product_faiss(v1.to_ref(), v2.to_ref()), trivial);
            assert_eq_float!(
                l2sqr_trivial(v1.to_ref(), v2.to_ref()),
                l2sqr_faiss(v1.to_ref(), v2.to_ref())
            );
            assert_eq_float!(
                l1_trivial(v1.to_ref(), v2.to_ref()),
                l1_faiss(v1.to_ref(), v2.to_ref())
            );
        }
        {
            // 100 is not a multiple of the 32-lane width used by `inner_product_simd`, so this
            // also exercises its scalar tail loop. The 128-dim case above has no tail.
            let v1 = gen_vector_for_test(100);
            let v2 = gen_vector_for_test(100);
            assert_eq_float!(
                inner_product_simd(v1.to_ref(), v2.to_ref()),
                inner_product_trivial(v1.to_ref(), v2.to_ref())
            );
        }
    }

    #[test]
    fn test_cosine_zero_vector() {
        let zero = VectorVal {
            inner: [0.0_f32; 4].map(Into::into).to_vec().into_boxed_slice(),
        };
        let other = VectorVal {
            inner: [1.0_f32, 2.0, 3.0, 4.0]
                .map(Into::into)
                .to_vec()
                .into_boxed_slice(),
        };
        // The free function reports an undefined cosine as NaN, whichever side is zero.
        assert!(cosine_distance(zero.to_ref(), other.to_ref()).is_nan());
        assert!(cosine_distance(other.to_ref(), zero.to_ref()).is_nan());
        // The bound measure falls back to a fixed, finite distance instead.
        assert_eq_float!(
            CosineDistance::new(zero.to_ref()).measure(other.to_ref()),
            1.1
        );
        assert_eq_float!(
            CosineDistance::new(other.to_ref()).measure(zero.to_ref()),
            1.1
        );
    }

    #[test]
    fn test_expect_distance() {
        assert_eq_float!(
            3.6808228,
            L1Distance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
        assert_eq_float!(
            2.054677,
            L2SqrDistance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
        assert_eq_float!(
            0.22848952,
            CosineDistance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
        assert_eq_float!(
            -3.2870955,
            InnerProductDistance::distance(
                VectorRef::from_slice_unchecked(&VEC1.map(Into::into)[..]),
                VectorRef::from_slice_unchecked(&VEC2.map(Into::into)[..]),
            )
        );
    }
}