risingwave_storage/vector/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod hnsw;
16
17pub mod utils;
18
19pub use risingwave_common::array::{
20    VectorDistanceType as VectorDistance, VectorItemType as VectorItem,
21};
22pub use risingwave_common::types::{VectorRef, VectorVal as Vector};
23pub use risingwave_common::vector::{MeasureDistance, MeasureDistanceBuilder};
24
25use crate::vector::utils::BoundedNearest;
26
27pub trait OnNearestItem<O> = for<'i> Fn(VectorRef<'i>, VectorDistance, &'i [u8]) -> O;
28
29pub struct NearestBuilder<'a, O, M: MeasureDistanceBuilder> {
30    measure: M::Measure<'a>,
31    nearest: BoundedNearest<O>,
32}
33
34impl<'a, O, M: MeasureDistanceBuilder> NearestBuilder<'a, O, M> {
35    pub fn new(target: VectorRef<'a>, n: usize) -> Self {
36        assert!(n > 0);
37        NearestBuilder {
38            measure: M::new(target),
39            nearest: BoundedNearest::new(n),
40        }
41    }
42
43    pub fn add<'b>(
44        &mut self,
45        vecs: impl IntoIterator<Item = (VectorRef<'b>, &'b [u8])> + 'b,
46        on_nearest_item: impl OnNearestItem<O>,
47    ) {
48        for (vec, info) in vecs {
49            let distance = self.measure.measure(vec);
50            self.nearest
51                .insert(distance, || on_nearest_item(vec, distance, info));
52        }
53    }
54
55    pub fn finish(self) -> Vec<O> {
56        self.nearest.collect()
57    }
58}
59
/// Helpers shared by vector tests; compiled under `cfg(test)` or the `test` feature.
#[cfg(any(test, feature = "test"))]
pub mod test_utils {
    use bytes::Bytes;

    use crate::vector::VectorDistance;

    /// Builds the opaque per-vector info for index `i`: its little-endian `usize` bytes.
    pub fn gen_info(i: usize) -> Bytes {
        Bytes::copy_from_slice(i.to_le_bytes().as_slice())
    }

    /// Reference top-`n` selection used to check nearest-search results.
    ///
    /// Sorts `input` by ascending distance with a stable sort (`total_cmp`, so NaNs have a
    /// defined order), then keeps at most the first `n` entries.
    pub fn top_n<O>(input: &mut Vec<(VectorDistance, O)>, n: usize) {
        input.sort_by(|(first_distance, _), (second_distance, _)| {
            first_distance.total_cmp(second_distance)
        });
        // `truncate` is a no-op when `n >= input.len()`, so no explicit clamp is needed.
        input.truncate(n);
    }

    // Random vector generator shared with `risingwave_common` tests.
    pub use risingwave_common::test_utils::rand_array::gen_vector_for_test as gen_vector;
}
82
#[cfg(test)]
mod tests {

    use bytes::Bytes;
    use itertools::Itertools;
    use risingwave_common::array::VectorVal;
    use risingwave_common::vector::MeasureDistanceBuilder;
    use risingwave_common::vector::distance::L2SqrDistance;

    use crate::vector::NearestBuilder;
    use crate::vector::test_utils::{gen_info, gen_vector, top_n};

    /// Generates `count` random vectors (via `gen_vector(10)`), each tagged with its index.
    fn gen_random_input(count: usize) -> Vec<(VectorVal, Bytes)> {
        (0..count).map(|i| (gen_vector(10), gen_info(i))).collect()
    }

    #[test]
    fn test_empty_top_n() {
        let vec = gen_vector(10);
        let builder = NearestBuilder::<'_, (), L2SqrDistance>::new(vec.to_ref(), 10);
        assert!(builder.finish().is_empty());
    }

    /// Checks `NearestBuilder` with capacity `n` against a sort-and-truncate reference
    /// over `count` random candidates.
    fn test_inner(count: usize, n: usize) {
        let input = gen_random_input(count);
        let vec = gen_vector(10);
        // Use the requested `n`; a hard-coded capacity would make `n` meaningless here.
        let mut builder = NearestBuilder::<'_, _, L2SqrDistance>::new(vec.to_ref(), n);
        builder.add(
            input.iter().map(|(v, b)| (v.to_ref(), b.as_ref())),
            |_, d, b| (d, Bytes::copy_from_slice(b)),
        );
        let output = builder.finish();
        let mut expected_output = input
            .into_iter()
            .map(|(v, b)| (L2SqrDistance::distance(vec.to_ref(), v.to_ref()), b))
            .collect_vec();
        top_n(&mut expected_output, n);
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_not_full_top_n() {
        test_inner(5, 10);
    }

    #[test]
    fn test_exact_size_top_n() {
        test_inner(10, 10);
    }

    #[test]
    fn test_oversize_top_n() {
        test_inner(20, 10);
    }

    #[test]
    fn test_small_n_top_n() {
        test_inner(20, 3);
    }

    #[test]
    fn test_single_top_n() {
        test_inner(20, 1);
    }

    #[test]
    fn test_not_full_small_n_top_n() {
        test_inner(2, 3);
    }
}