risingwave_storage/vector/
distance.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::simd::Simd;
16use std::simd::num::SimdFloat;
17
18use crate::vector::{
19    MeasureDistance, MeasureDistanceBuilder, VectorDistance, VectorItem, VectorRef,
20};
21
/// Invokes the callback macro named by the first identifier with the list of all supported
/// distance measurements.
///
/// The callback receives, as its first token tree, a brace-delimited list of
/// `(VariantName, DistanceType),` pairs. Any tokens written after the callback name are
/// forwarded verbatim after that list.
#[macro_export]
macro_rules! for_all_distance_measurement {
    ($callback:ident $($extra:tt)*) => {
        $callback! {
            {
                (L1, $crate::vector::distance::L1Distance),
                (L2, $crate::vector::distance::L2Distance),
                (Cosine, $crate::vector::distance::CosineDistance),
                (InnerProduct, $crate::vector::distance::InnerProductDistance),
            }
            $($extra)*
        }
    };
}
36
37macro_rules! define_measure {
38    ({
39        $(($distance_name:ident, $_distance_type:ty),)+
40    }) => {
41        pub enum DistanceMeasurement {
42            $($distance_name),+
43        }
44    };
45    () => {
46        for_all_distance_measurement! {define_measure}
47    };
48}
49
50define_measure!();
51
/// Evaluates `$body` with a local type alias bound to the distance type that corresponds to
/// the runtime `DistanceMeasurement` value.
///
/// Call it as `dispatch_measurement!(measurement, Alias, body)`: the expansion is a `match`
/// over `measurement` in which every arm declares `type Alias = <distance type>;` and then
/// evaluates `body`.
///
/// The generated arms name `DistanceMeasurement` without a path, and the callback is invoked
/// as a bare `dispatch_measurement!`, so both must be in scope at the call site.
#[macro_export]
macro_rules! dispatch_measurement {
    // Callback arm, reached through `for_all_distance_measurement!`. It must stay first so
    // that the brace-delimited list is not parsed as the `$measurement` expression.
    (
        { $(($variant:ident, $distance_ty:ty),)+ },
        $measurement:expr,
        $alias:ident,
        $body:expr
    ) => {
        match $measurement {
            $(
                DistanceMeasurement::$variant => {
                    type $alias = $distance_ty;
                    $body
                }
            ),+
        }
    };
    // User-facing arm: fetch the measurement list and re-enter through the arm above.
    ($measurement:expr, $alias:ident, $body:expr) => {
        $crate::for_all_distance_measurement! {dispatch_measurement, $measurement, $alias, $body}
    };
}
71
/// Marker type selecting the L1 distance; its `MeasureDistanceBuilder` impl produces
/// [`L1DistanceMeasure`].
pub struct L1Distance;
73
74pub struct L1DistanceMeasure<'a>(VectorRef<'a>);
75
76impl MeasureDistanceBuilder for L1Distance {
77    type Measure<'a> = L1DistanceMeasure<'a>;
78
79    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
80        L1DistanceMeasure(target)
81    }
82}
83
84#[cfg_attr(not(test), expect(dead_code))]
85fn l1_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
86    let len = first.0.len();
87    assert_eq!(len, second.0.len());
88    (0..len)
89        .map(|i| {
90            let diff = first.0[i] - second.0[i];
91            diff.abs()
92        })
93        .sum()
94}
95
96fn l1_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
97    faiss::utils::fvec_l1(first.0, second.0)
98}
99
100impl<'a> MeasureDistance for L1DistanceMeasure<'a> {
101    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
102        l1_faiss(self.0, other)
103    }
104}
105
/// Marker type selecting the (squared) L2 distance; its `MeasureDistanceBuilder` impl produces
/// [`L2DistanceMeasure`].
pub struct L2Distance;
107
108/// Measure the l2 distance
109///
110/// In this implementation, we don't take the square root to avoid unnecessary computation, because
111/// we only want comparison rather than the actual distance.
112pub struct L2DistanceMeasure<'a>(VectorRef<'a>);
113
114impl MeasureDistanceBuilder for L2Distance {
115    type Measure<'a> = L2DistanceMeasure<'a>;
116
117    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
118        L2DistanceMeasure(target)
119    }
120}
121
122#[cfg_attr(not(test), expect(dead_code))]
123fn l2_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
124    let len = first.0.len();
125    assert_eq!(len, second.0.len());
126    (0..len).map(|i| (first.0[i] - second.0[i]).powi(2)).sum()
127}
128
129fn l2_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
130    faiss::utils::fvec_l2sqr(first.0, second.0)
131}
132
133impl<'a> MeasureDistance for L2DistanceMeasure<'a> {
134    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
135        l2_faiss(self.0, other)
136    }
137}
138
/// Marker type selecting the cosine distance; its `MeasureDistanceBuilder` impl produces
/// [`CosineDistanceMeasure`].
pub struct CosineDistance;
140pub struct CosineDistanceMeasure<'a> {
141    target: VectorRef<'a>,
142    magnitude: VectorItem,
143}
144
145impl MeasureDistanceBuilder for CosineDistance {
146    type Measure<'a> = CosineDistanceMeasure<'a>;
147
148    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
149        let magnitude = target.magnitude();
150        CosineDistanceMeasure { target, magnitude }
151    }
152}
153
154impl<'a> MeasureDistance for CosineDistanceMeasure<'a> {
155    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
156        let len = self.target.0.len();
157        assert_eq!(len, other.0.len());
158        let magnitude_mul = other.magnitude() * self.magnitude;
159        if magnitude_mul < f32::MIN_POSITIVE {
160            // If either vector is zero, the distance is the further 1.1
161            return 1.1;
162        }
163        1.0 - inner_product_faiss(self.target, other) / magnitude_mul
164    }
165}
166
/// Marker type selecting the inner-product distance; its `MeasureDistanceBuilder` impl
/// produces [`InnerProductDistanceMeasure`].
pub struct InnerProductDistance;
168pub struct InnerProductDistanceMeasure<'a>(VectorRef<'a>);
169
170impl MeasureDistanceBuilder for InnerProductDistance {
171    type Measure<'a> = InnerProductDistanceMeasure<'a>;
172
173    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
174        InnerProductDistanceMeasure(target)
175    }
176}
177
178#[cfg_attr(not(test), expect(dead_code))]
179fn inner_product_trivial(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
180    let len = first.0.len();
181    assert_eq!(len, second.0.len());
182    (0..len)
183        .map(|i| first.0[i] * second.0[i])
184        .sum::<VectorItem>()
185}
186
187#[cfg_attr(not(test), expect(dead_code))]
188fn inner_product_simd(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
189    let len = first.0.len();
190    assert_eq!(len, second.0.len());
191    let mut sum = 0.0;
192    let mut start = 0;
193    let mut end = start + 32;
194    while end <= len {
195        let this = Simd::<VectorItem, 32>::from_slice(&first.0[start..end]);
196        let target = Simd::<VectorItem, 32>::from_slice(&second.0[start..end]);
197        sum += (this * target).reduce_sum();
198        start += 32;
199        end += 32;
200    }
201    (start..len)
202        .map(|i| first.0[i] * second.0[i])
203        .sum::<VectorDistance>()
204        + sum
205}
206
207fn inner_product_faiss(first: VectorRef<'_>, second: VectorRef<'_>) -> VectorDistance {
208    faiss::utils::fvec_inner_product(first.0, second.0)
209}
210
211impl<'a> MeasureDistance for InnerProductDistanceMeasure<'a> {
212    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
213        -inner_product_faiss(self.0, other)
214    }
215}
216
217#[cfg(test)]
218mod tests {
219
220    use expect_test::expect;
221
222    use super::*;
223    use crate::vector::test_utils::gen_vector;
224    use crate::vector::{MeasureDistanceBuilder, VectorInner};
225
226    const VECTOR_LEN: usize = 10;
227
228    const VEC1: [f32; VECTOR_LEN] = [
229        0.45742255, 0.04135585, 0.7236407, 0.82355756, 0.837814, 0.09387952, 0.8907283, 0.20203716,
230        0.2039721, 0.7972273,
231    ];
232
233    const VEC2: [f32; VECTOR_LEN] = [
234        0.9755903, 0.42836714, 0.45131344, 0.8602846, 0.61997443, 0.9501612, 0.65076965,
235        0.22877127, 0.97690505, 0.44438475,
236    ];
237
238    const FLOAT_ALLOWED_BIAS: f32 = 1e-5;
239
240    macro_rules! assert_eq_float {
241        ($first:expr, $second:expr) => {
242            assert!(
243                ($first - $second) < FLOAT_ALLOWED_BIAS,
244                "Expected: {}, Actual: {}",
245                $second,
246                $first
247            );
248        };
249    }
250
251    #[test]
252    fn test_distance() {
253        let first_vec = [0.238474, 0.578234];
254        let second_vec = [0.9327183, 0.387495];
255        let [v1_1, v1_2] = first_vec;
256        let [v2_1, v2_2] = second_vec;
257        let first_vec = VectorInner(&first_vec[..]);
258        let second_vec = VectorInner(&second_vec[..]);
259        assert_eq_float!(
260            L1Distance::distance(first_vec, second_vec),
261            (v1_1 - v2_1).abs() + (v1_2 - v2_2).abs()
262        );
263        assert_eq_float!(
264            L2Distance::distance(first_vec, second_vec),
265            (v1_1 - v2_1).powi(2) + (v1_2 - v2_2).powi(2)
266        );
267        assert_eq_float!(
268            CosineDistance::distance(first_vec, second_vec),
269            1.0 - (v1_1 * v2_1 + v1_2 * v2_2)
270                / ((v1_1.powi(2) + v1_2.powi(2)).sqrt() * (v2_1.powi(2) + v2_2.powi(2)).sqrt())
271        );
272        assert_eq_float!(
273            InnerProductDistance::distance(first_vec, second_vec),
274            -(v1_1 * v2_1 + v1_2 * v2_2)
275        );
276        {
277            let v1 = gen_vector(128);
278            let v2 = gen_vector(128);
279            let trivial = inner_product_trivial(v1.to_ref(), v2.to_ref());
280            assert_eq_float!(inner_product_simd(v1.to_ref(), v2.to_ref()), trivial);
281            assert_eq_float!(inner_product_faiss(v1.to_ref(), v2.to_ref()), trivial);
282            assert_eq_float!(
283                l2_trivial(v1.to_ref(), v2.to_ref()),
284                l2_faiss(v1.to_ref(), v2.to_ref())
285            );
286            assert_eq_float!(
287                l1_trivial(v1.to_ref(), v2.to_ref()),
288                l1_faiss(v1.to_ref(), v2.to_ref())
289            );
290        }
291    }
292
293    #[test]
294    fn test_expect_distance() {
295        expect![[r#"
296            3.6808228
297        "#]]
298        .assert_debug_eq(&L1Distance::distance(
299            VectorInner(&VEC1),
300            VectorInner(&VEC2),
301        ));
302        expect![[r#"
303            2.054677
304        "#]]
305        .assert_debug_eq(&L2Distance::distance(
306            VectorInner(&VEC1),
307            VectorInner(&VEC2),
308        ));
309        expect![[r#"
310            0.22848952
311        "#]]
312        .assert_debug_eq(&CosineDistance::distance(
313            VectorInner(&VEC1),
314            VectorInner(&VEC2),
315        ));
316        expect![[r#"
317            -3.2870955
318        "#]]
319        .assert_debug_eq(&InnerProductDistance::distance(
320            VectorInner(&VEC1),
321            VectorInner(&VEC2),
322        ));
323    }
324}