risingwave_storage/vector/
distance.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::vector::{
16    MeasureDistance, MeasureDistanceBuilder, VectorDistance, VectorItem, VectorRef,
17};
18
/// Invokes a callback macro with the full table of supported distance measurements.
///
/// The first identifier is the name of the callback macro. The callback is invoked with a
/// brace-delimited list of `(VariantName, path::to::DistanceType),` entries as its first
/// token tree, followed verbatim by every token written after the callback name at the
/// call site.
#[macro_export]
macro_rules! for_all_distance_measurement {
    ($callback:ident $($forwarded:tt)*) => {
        $callback! {
            {
                (L1, $crate::vector::distance::L1Distance),
                (L2, $crate::vector::distance::L2Distance),
                (Cosine, $crate::vector::distance::CosineDistance),
                (InnerProduct, $crate::vector::distance::InnerProductDistance),
            }
            $($forwarded)*
        }
    };
}
33
/// Generates the `DistanceMeasurement` enum.
///
/// Invoked without arguments, it re-enters itself through `for_all_distance_measurement!`.
/// Invoked with the measurement table, it emits one unit variant per entry, named after the
/// entry's first element; the distance type in each entry is not used here.
macro_rules! define_measure {
    () => {
        for_all_distance_measurement! {define_measure}
    };
    ({
        $(($variant:ident, $_builder:ty),)+
    }) => {
        pub enum DistanceMeasurement {
            $($variant),+
        }
    };
}
46
// Expands, via `for_all_distance_measurement!`, to `pub enum DistanceMeasurement` with one
// variant per entry of the measurement table (`L1`, `L2`, `Cosine`, `InnerProduct`).
define_measure!();
48
/// Evaluates `$body` with `$type_name` bound to the distance type that corresponds to a
/// runtime `DistanceMeasurement` value.
///
/// Usage: `dispatch_measurement!(measurement, TypeAlias, body_expr)`.
///
/// The macro expands to a `match` on `$measurement` with one arm per entry of the measurement
/// table. Each arm declares `type $type_name = <distance type>;` and then evaluates `$body`,
/// so `$body` is instantiated once per distance type and may use `$type_name` as a type.
///
/// The arm beginning with a brace-delimited table is the internal form that
/// `for_all_distance_measurement!` calls back into.
#[macro_export]
macro_rules! dispatch_measurement {
    ({
        $(($distance_name:ident, $distance_type:ty),)+
    },
    $measurement:expr, $type_name:ident, $body:expr) => {
        match $measurement {
            $(
                // Fully qualified through `$crate`, like the distance types in the table, so
                // that call sites do not need `DistanceMeasurement` imported into scope.
                $crate::vector::distance::DistanceMeasurement::$distance_name => {
                    type $type_name = $distance_type;
                    $body
                }
            ),+
        }
    };
    ($measurement:expr, $type_name:ident, $body:expr) => {
        $crate::for_all_distance_measurement! {dispatch_measurement, $measurement, $type_name, $body}
    };
}
68
/// Builder type for the L1 distance: the sum of absolute element-wise differences.
pub struct L1Distance;

/// L1 distance measure bound to a target vector; other vectors are measured against it.
pub struct L1DistanceMeasure<'a>(VectorRef<'a>);

impl MeasureDistanceBuilder for L1Distance {
    type Measure<'a> = L1DistanceMeasure<'a>;

    /// Creates a measure that borrows `target` as the vector to measure against.
    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
        L1DistanceMeasure(target)
    }
}
80
81impl<'a> MeasureDistance<'a> for L1DistanceMeasure<'a> {
82    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
83        // TODO: use some library with simd support
84        let len = self.0.0.len();
85        assert_eq!(len, other.0.len());
86        // In this implementation, we don't take the square root to avoid unnecessary computation, because
87        // we only want comparison rather than the actual distance.
88        (0..len)
89            .map(|i| {
90                let diff = self.0.0[i] - other.0[i];
91                diff.abs()
92            })
93            .sum()
94    }
95}
96
/// Builder type for the squared L2 distance: the sum of squared element-wise differences.
///
/// The measure it produces does not take the final square root.
pub struct L2Distance;

/// Squared L2 distance measure bound to a target vector; other vectors are measured against it.
pub struct L2DistanceMeasure<'a>(VectorRef<'a>);

impl MeasureDistanceBuilder for L2Distance {
    type Measure<'a> = L2DistanceMeasure<'a>;

    /// Creates a measure that borrows `target` as the vector to measure against.
    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
        L2DistanceMeasure(target)
    }
}
108
109impl<'a> MeasureDistance<'a> for L2DistanceMeasure<'a> {
110    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
111        // TODO: use some library with simd support
112        let len = self.0.0.len();
113        assert_eq!(len, other.0.len());
114        // In this implementation, we don't take the square root to avoid unnecessary computation, because
115        // we only want comparison rather than the actual distance.
116        (0..len).map(|i| (self.0.0[i] - other.0[i]).powi(2)).sum()
117    }
118}
119
120pub struct CosineDistance;
121pub struct CosineDistanceMeasure<'a> {
122    target: VectorRef<'a>,
123    magnitude: VectorItem,
124}
125
126impl MeasureDistanceBuilder for CosineDistance {
127    type Measure<'a> = CosineDistanceMeasure<'a>;
128
129    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
130        let magnitude = target.magnitude();
131        CosineDistanceMeasure { target, magnitude }
132    }
133}
134
135impl<'a> MeasureDistance<'a> for CosineDistanceMeasure<'a> {
136    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
137        // TODO: use some library with simd support
138        let len = self.target.0.len();
139        assert_eq!(len, other.0.len());
140        let magnitude_mul = other.magnitude() * self.magnitude;
141        1.0 - (0..len)
142            .map(|i| self.target.0[i] * other.0[i] / magnitude_mul)
143            .sum::<VectorDistance>()
144    }
145}
146
/// Builder type for the inner-product distance: the negated sum of element-wise products.
pub struct InnerProductDistance;

/// Inner-product distance measure bound to a target vector; other vectors are measured
/// against it.
pub struct InnerProductDistanceMeasure<'a>(VectorRef<'a>);

impl MeasureDistanceBuilder for InnerProductDistance {
    type Measure<'a> = InnerProductDistanceMeasure<'a>;

    /// Creates a measure that borrows `target` as the vector to measure against.
    fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
        InnerProductDistanceMeasure(target)
    }
}
157
158impl<'a> MeasureDistance<'a> for InnerProductDistanceMeasure<'a> {
159    fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
160        // TODO: use some library with simd support
161        let len = self.0.0.len();
162        assert_eq!(len, other.0.len());
163        -(0..len)
164            .map(|i| self.0.0[i] * other.0[i])
165            .sum::<VectorDistance>()
166    }
167}
168
// Unit tests for the distance measures: one test compares against hand-written formulas on
// a two-element vector pair, the other pins snapshot values for two ten-element vectors.
#[cfg(test)]
mod tests {

    use expect_test::expect;

    use super::*;
    use crate::vector::{MeasureDistanceBuilder, Vector, VectorInner};

    // Number of elements in the fixed snapshot vectors below.
    const VECTOR_LEN: usize = 10;

    // First fixed input vector for the snapshot test.
    const VEC1: [f32; VECTOR_LEN] = [
        0.45742255, 0.04135585, 0.7236407, 0.82355756, 0.837814, 0.09387952, 0.8907283, 0.20203716,
        0.2039721, 0.7972273,
    ];

    // Second fixed input vector for the snapshot test.
    const VEC2: [f32; VECTOR_LEN] = [
        0.9755903, 0.42836714, 0.45131344, 0.8602846, 0.61997443, 0.9501612, 0.65076965,
        0.22877127, 0.97690505, 0.44438475,
    ];

    // Checks each distance against the formula written out by hand for a two-element pair.
    // The comparisons are exact (`assert_eq!` on floats), so the expected expressions must
    // keep evaluating to the same values the implementation produces.
    #[test]
    fn test_distance() {
        let first_vec = [0.238474, 0.578234];
        let second_vec = [0.9327183, 0.387495];
        let [v1_1, v1_2] = first_vec;
        let [v2_1, v2_2] = second_vec;
        let first_vec = VectorInner(&first_vec[..]);
        let second_vec = VectorInner(&second_vec[..]);
        {
            // `magnitude` is expected to equal the square root of the sum of squares.
            assert_eq!(first_vec.magnitude(), (v1_1.powi(2) + v1_2.powi(2)).sqrt());
            let mut normalized_vec =
                Vector::new(&[v1_1 / first_vec.magnitude(), v1_2 / first_vec.magnitude()]);
            assert_eq!(first_vec.normalized(), normalized_vec);
            // `get_mut` is expected to return `Some` only while no clone of the vector is
            // alive: `Some` before cloning, `None` on both handles after cloning, and `Some`
            // again on the clone once the original has been dropped.
            assert!(normalized_vec.get_mut().is_some());
            let mut normalized_vec_clone = normalized_vec.clone();
            assert!(normalized_vec.get_mut().is_none());
            assert!(normalized_vec_clone.get_mut().is_none());
            drop(normalized_vec);
            assert!(normalized_vec_clone.get_mut().is_some());
        }
        // L1: sum of absolute differences.
        assert_eq!(
            L1Distance::distance(first_vec, second_vec),
            (v1_1 - v2_1).abs() + (v1_2 - v2_2).abs()
        );
        // L2: sum of squared differences, with no square root.
        assert_eq!(
            L2Distance::distance(first_vec, second_vec),
            (v1_1 - v2_1).powi(2) + (v1_2 - v2_2).powi(2)
        );
        // Cosine: one minus the dot product divided by the product of the magnitudes.
        assert_eq!(
            CosineDistance::distance(first_vec, second_vec),
            1.0 - (v1_1 * v2_1 + v1_2 * v2_2)
                / ((v1_1.powi(2) + v1_2.powi(2)).sqrt() * (v2_1.powi(2) + v2_2.powi(2)).sqrt())
        );
        // Inner product: negated dot product.
        assert_eq!(
            InnerProductDistance::distance(first_vec, second_vec),
            -(v1_1 * v2_1 + v1_2 * v2_2)
        );
    }

    // Snapshot test: pins the `Debug` output of each distance between `VEC1` and `VEC2`.
    #[test]
    fn test_expect_distance() {
        expect![[r#"
            3.6808228
        "#]]
        .assert_debug_eq(&L1Distance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
        expect![[r#"
            2.054677
        "#]]
        .assert_debug_eq(&L2Distance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
        expect![[r#"
            0.22848958
        "#]]
        .assert_debug_eq(&CosineDistance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
        expect![[r#"
            -3.2870955
        "#]]
        .assert_debug_eq(&InnerProductDistance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
    }
}