risingwave_storage/vector/
distance.rs1use crate::vector::{
16 MeasureDistance, MeasureDistanceBuilder, VectorDistance, VectorItem, VectorRef,
17};
18
/// Invokes `$callback!` with the table of every supported distance measurement, followed by
/// any extra tokens supplied after the callback name.
///
/// The table is a brace-delimited list of `(VariantName, path::to::BuilderType),` entries
/// (each entry carries a trailing comma), so a callback can match it with
/// `{ $(($name:ident, $ty:ty),)+ }` and then consume the extra tokens.
///
/// The expansion uses a brace-delimited invocation so it is valid in item position as well as
/// expression position.
#[macro_export]
macro_rules! for_all_distance_measurement {
    ($callback:ident $($extra:tt)*) => {
        $callback! {
            {
                (L1, $crate::vector::distance::L1Distance),
                (L2, $crate::vector::distance::L2Distance),
                (Cosine, $crate::vector::distance::CosineDistance),
                (InnerProduct, $crate::vector::distance::InnerProductDistance),
            }
            $($extra)*
        }
    };
}
33
/// Generates the `DistanceMeasurement` enum.
///
/// With no arguments it asks `for_all_distance_measurement!` for the measurement table and
/// re-enters itself with that table. With a table it emits one unit variant per entry; the
/// builder type in each entry is not used here.
macro_rules! define_measure {
    () => {
        for_all_distance_measurement! {define_measure}
    };
    ({
        $(($variant:ident, $_builder:ty),)+
    }) => {
        /// A distance measurement; one variant per entry of the measurement table.
        pub enum DistanceMeasurement {
            $($variant),+
        }
    };
}
46
// Expand the `DistanceMeasurement` enum from the table provided by
// `for_all_distance_measurement!`.
define_measure!();
48
/// Matches `$measurement` (a `DistanceMeasurement`) and evaluates `$body` in every arm, with
/// `$type_name` bound as a local type alias to that variant's builder type.
///
/// ```ignore
/// dispatch_measurement!(measurement, M, M::distance(a, b))
/// ```
///
/// The three-argument form fetches the measurement table from
/// `for_all_distance_measurement!`, which calls back into the table-first form below.
///
/// The enum is named through `$crate`, so callers do not need `DistanceMeasurement` in scope.
///
/// NOTE(review): the callback passed to `for_all_distance_measurement!` is the bare name
/// `dispatch_measurement`, which resolves at the caller's site. Callers must therefore make
/// this macro nameable there (e.g. import it), not only invoke it by full path.
#[macro_export]
macro_rules! dispatch_measurement {
    ({
        $(($distance_name:ident, $distance_type:ty),)+
    },
    $measurement:expr, $type_name:ident, $body:expr) => {
        match $measurement {
            $(
                // Fully qualified: a bare `DistanceMeasurement` would resolve at the call site
                // and fail for callers that have not imported the enum.
                $crate::vector::distance::DistanceMeasurement::$distance_name => {
                    type $type_name = $distance_type;
                    $body
                }
            ),+
        }
    };
    ($measurement:expr, $type_name:ident, $body:expr) => {
        $crate::for_all_distance_measurement! {dispatch_measurement, $measurement, $type_name, $body}
    };
}
68
69pub struct L1Distance;
70
71pub struct L1DistanceMeasure<'a>(VectorRef<'a>);
72
73impl MeasureDistanceBuilder for L1Distance {
74 type Measure<'a> = L1DistanceMeasure<'a>;
75
76 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
77 L1DistanceMeasure(target)
78 }
79}
80
81impl<'a> MeasureDistance<'a> for L1DistanceMeasure<'a> {
82 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
83 let len = self.0.0.len();
85 assert_eq!(len, other.0.len());
86 (0..len)
89 .map(|i| {
90 let diff = self.0.0[i] - other.0[i];
91 diff.abs()
92 })
93 .sum()
94 }
95}
96
97pub struct L2Distance;
98
99pub struct L2DistanceMeasure<'a>(VectorRef<'a>);
100
101impl MeasureDistanceBuilder for L2Distance {
102 type Measure<'a> = L2DistanceMeasure<'a>;
103
104 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
105 L2DistanceMeasure(target)
106 }
107}
108
109impl<'a> MeasureDistance<'a> for L2DistanceMeasure<'a> {
110 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
111 let len = self.0.0.len();
113 assert_eq!(len, other.0.len());
114 (0..len).map(|i| (self.0.0[i] - other.0[i]).powi(2)).sum()
117 }
118}
119
120pub struct CosineDistance;
121pub struct CosineDistanceMeasure<'a> {
122 target: VectorRef<'a>,
123 magnitude: VectorItem,
124}
125
126impl MeasureDistanceBuilder for CosineDistance {
127 type Measure<'a> = CosineDistanceMeasure<'a>;
128
129 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
130 let magnitude = target.magnitude();
131 CosineDistanceMeasure { target, magnitude }
132 }
133}
134
135impl<'a> MeasureDistance<'a> for CosineDistanceMeasure<'a> {
136 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
137 let len = self.target.0.len();
139 assert_eq!(len, other.0.len());
140 let magnitude_mul = other.magnitude() * self.magnitude;
141 1.0 - (0..len)
142 .map(|i| self.target.0[i] * other.0[i] / magnitude_mul)
143 .sum::<VectorDistance>()
144 }
145}
146
147pub struct InnerProductDistance;
148pub struct InnerProductDistanceMeasure<'a>(VectorRef<'a>);
149
150impl MeasureDistanceBuilder for InnerProductDistance {
151 type Measure<'a> = InnerProductDistanceMeasure<'a>;
152
153 fn new(target: VectorRef<'_>) -> Self::Measure<'_> {
154 InnerProductDistanceMeasure(target)
155 }
156}
157
158impl<'a> MeasureDistance<'a> for InnerProductDistanceMeasure<'a> {
159 fn measure(&self, other: VectorRef<'_>) -> VectorDistance {
160 let len = self.0.0.len();
162 assert_eq!(len, other.0.len());
163 -(0..len)
164 .map(|i| self.0.0[i] * other.0[i])
165 .sum::<VectorDistance>()
166 }
167}
168
#[cfg(test)]
mod tests {

    use expect_test::expect;

    use super::*;
    use crate::vector::{MeasureDistanceBuilder, Vector, VectorInner};

    /// Dimension of the fixed vectors used by `test_expect_distance`.
    const VECTOR_LEN: usize = 10;

    /// First fixed operand for the snapshot test.
    const VEC1: [f32; VECTOR_LEN] = [
        0.45742255, 0.04135585, 0.7236407, 0.82355756, 0.837814, 0.09387952, 0.8907283, 0.20203716,
        0.2039721, 0.7972273,
    ];

    /// Second fixed operand for the snapshot test.
    const VEC2: [f32; VECTOR_LEN] = [
        0.9755903, 0.42836714, 0.45131344, 0.8602846, 0.61997443, 0.9501612, 0.65076965,
        0.22877127, 0.97690505, 0.44438475,
    ];

    /// Checks the magnitude/normalization helpers, `Vector::get_mut` sharing behavior, and
    /// each distance measurement against a hand-written formula on 2-D vectors.
    ///
    /// The comparisons use exact `assert_eq!` on `f32`. Reordering the floating-point
    /// operations in an implementation can therefore make them fail.
    #[test]
    fn test_distance() {
        let first_vec = [0.238474, 0.578234];
        let second_vec = [0.9327183, 0.387495];
        let [v1_1, v1_2] = first_vec;
        let [v2_1, v2_2] = second_vec;
        let first_vec = VectorInner(&first_vec[..]);
        let second_vec = VectorInner(&second_vec[..]);
        {
            assert_eq!(first_vec.magnitude(), (v1_1.powi(2) + v1_2.powi(2)).sqrt());
            let mut normalized_vec =
                Vector::new(&[v1_1 / first_vec.magnitude(), v1_2 / first_vec.magnitude()]);
            assert_eq!(first_vec.normalized(), normalized_vec);
            // The assertions below expect `get_mut` to return `Some` only while no clone
            // shares the vector's storage: it fails for both handles after `clone()` and
            // succeeds again once the other handle is dropped.
            assert!(normalized_vec.get_mut().is_some());
            let mut normalized_vec_clone = normalized_vec.clone();
            assert!(normalized_vec.get_mut().is_none());
            assert!(normalized_vec_clone.get_mut().is_none());
            drop(normalized_vec);
            assert!(normalized_vec_clone.get_mut().is_some());
        }
        assert_eq!(
            L1Distance::distance(first_vec, second_vec),
            (v1_1 - v2_1).abs() + (v1_2 - v2_2).abs()
        );
        // L2 is the squared Euclidean distance: no square root is taken.
        assert_eq!(
            L2Distance::distance(first_vec, second_vec),
            (v1_1 - v2_1).powi(2) + (v1_2 - v2_2).powi(2)
        );
        assert_eq!(
            CosineDistance::distance(first_vec, second_vec),
            1.0 - (v1_1 * v2_1 + v1_2 * v2_2)
                / ((v1_1.powi(2) + v1_2.powi(2)).sqrt() * (v2_1.powi(2) + v2_2.powi(2)).sqrt())
        );
        // Inner-product "distance" is the negated dot product.
        assert_eq!(
            InnerProductDistance::distance(first_vec, second_vec),
            -(v1_1 * v2_1 + v1_2 * v2_2)
        );
    }

    /// Snapshot test of every measurement on the fixed 10-D vectors `VEC1` and `VEC2`.
    ///
    /// If an implementation changes intentionally, regenerate the snapshots with
    /// `UPDATE_EXPECT=1 cargo test`.
    #[test]
    fn test_expect_distance() {
        expect![[r#"
            3.6808228
        "#]]
        .assert_debug_eq(&L1Distance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
        expect![[r#"
            2.054677
        "#]]
        .assert_debug_eq(&L2Distance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
        expect![[r#"
            0.22848958
        "#]]
        .assert_debug_eq(&CosineDistance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
        expect![[r#"
            -3.2870955
        "#]]
        .assert_debug_eq(&InnerProductDistance::distance(
            VectorInner(&VEC1),
            VectorInner(&VEC2),
        ));
    }
}