risingwave_storage/vector/
mod.rs1pub mod hnsw;
16
17pub mod utils;
18
19pub use risingwave_common::array::{
20 VectorDistanceType as VectorDistance, VectorItemType as VectorItem,
21};
22pub use risingwave_common::types::{VectorRef, VectorVal as Vector};
23pub use risingwave_common::vector::{MeasureDistance, MeasureDistanceBuilder};
24
25use crate::vector::utils::BoundedNearest;
26
27pub trait OnNearestItem<O> = for<'i> Fn(VectorRef<'i>, VectorDistance, &'i [u8]) -> O;
28
29pub struct NearestBuilder<'a, O, M: MeasureDistanceBuilder> {
30 measure: M::Measure<'a>,
31 nearest: BoundedNearest<O>,
32}
33
34impl<'a, O, M: MeasureDistanceBuilder> NearestBuilder<'a, O, M> {
35 pub fn new(target: VectorRef<'a>, n: usize) -> Self {
36 assert!(n > 0);
37 NearestBuilder {
38 measure: M::new(target),
39 nearest: BoundedNearest::new(n),
40 }
41 }
42
43 pub fn add<'b>(
44 &mut self,
45 vecs: impl IntoIterator<Item = (VectorRef<'b>, &'b [u8])> + 'b,
46 on_nearest_item: impl OnNearestItem<O>,
47 ) {
48 for (vec, info) in vecs {
49 let distance = self.measure.measure(vec);
50 self.nearest
51 .insert(distance, || on_nearest_item(vec, distance, info));
52 }
53 }
54
55 pub fn finish(self) -> Vec<O> {
56 self.nearest.collect()
57 }
58}
59
/// Helpers shared by vector-related tests; compiled for tests or with the `test` feature.
#[cfg(any(test, feature = "test"))]
pub mod test_utils {
    use bytes::Bytes;

    use crate::vector::VectorDistance;

    /// Encodes `i` as its little-endian bytes, used as the info payload of a test vector.
    pub fn gen_info(i: usize) -> Bytes {
        Bytes::copy_from_slice(&i.to_le_bytes())
    }

    /// Sorts `input` by ascending distance (total order, stable) and keeps at most the
    /// first `n` entries.
    pub fn top_n<O>(input: &mut Vec<(VectorDistance, O)>, n: usize) {
        input.sort_by(|(lhs, _), (rhs, _)| lhs.total_cmp(rhs));
        // `truncate` is a no-op when `n >= input.len()`, so no explicit clamping is needed.
        input.truncate(n);
    }

    pub use risingwave_common::test_utils::rand_array::gen_vector_for_test as gen_vector;
}
82
#[cfg(test)]
mod tests {

    use bytes::Bytes;
    use itertools::Itertools;
    use risingwave_common::array::VectorVal;
    use risingwave_common::vector::MeasureDistanceBuilder;
    use risingwave_common::vector::distance::L2SqrDistance;

    use crate::vector::NearestBuilder;
    use crate::vector::test_utils::{gen_info, gen_vector, top_n};

    /// Builds `count` pairs of a generated vector and the info bytes for its index.
    fn gen_random_input(count: usize) -> Vec<(VectorVal, Bytes)> {
        (0..count).map(|i| (gen_vector(10), gen_info(i))).collect()
    }

    /// A builder that never receives any candidate must finish with no output.
    #[test]
    fn test_empty_top_n() {
        let vec = gen_vector(10);
        let builder = NearestBuilder::<'_, (), L2SqrDistance>::new(vec.to_ref(), 10);
        assert!(builder.finish().is_empty());
    }

    /// Feeds `count` generated vectors into a builder bounded by `n` and checks the
    /// result against a sort-and-truncate reference computed with `top_n`.
    fn test_inner(count: usize, n: usize) {
        let input = gen_random_input(count);
        let vec = gen_vector(10);
        // Use the same bound `n` for the builder and for the expected result below.
        let mut builder = NearestBuilder::<'_, _, L2SqrDistance>::new(vec.to_ref(), n);
        builder.add(
            input.iter().map(|(v, b)| (v.to_ref(), b.as_ref())),
            |_, d, b| (d, Bytes::copy_from_slice(b)),
        );
        let output = builder.finish();
        let mut expected_output = input
            .into_iter()
            .map(|(v, b)| (L2SqrDistance::distance(vec.to_ref(), v.to_ref()), b))
            .collect_vec();
        top_n(&mut expected_output, n);
        assert_eq!(output, expected_output);
    }

    /// Fewer candidates than the bound.
    #[test]
    fn test_not_full_top_n() {
        test_inner(5, 10);
    }

    /// Exactly as many candidates as the bound.
    #[test]
    fn test_exact_size_top_n() {
        test_inner(10, 10);
    }

    /// More candidates than the bound.
    #[test]
    fn test_oversize_top_n() {
        test_inner(20, 10);
    }
}