risingwave_storage/vector/
hnsw.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::min;
16use std::marker::PhantomData;
17
18use faiss::index::hnsw::Hnsw;
19use rand::Rng;
20use rand::distr::uniform::{UniformFloat, UniformSampler};
21
22use crate::hummock::HummockResult;
23use crate::vector::utils::{BoundedNearest, MinDistanceHeap};
24use crate::vector::{
25    MeasureDistance, MeasureDistanceBuilder, OnNearestItem, VectorDistance, VectorInner,
26    VectorItem, VectorRef,
27};
28
/// Tuning knobs for building an HNSW graph.
pub struct HnswBuilderOptions {
    /// Neighbour budget per node on the upper levels; the ground level gets twice as many
    /// (see `level_m`).
    pub m: usize,
    /// Candidate-list size used when searching for a new node's neighbours during insertion.
    pub ef_construction: usize,
    /// Largest level index a node can be assigned to.
    pub max_level: usize,
}

impl HnswBuilderOptions {
    /// Maximum number of neighbours kept on `level`.
    ///
    /// Borrowed from pg_vector: the ground level keeps `2 * m` connections, every other level `m`.
    fn level_m(&self, level: usize) -> usize {
        match level {
            0 => self.m * 2,
            _ => self.m,
        }
    }

    /// Level-generation normalisation factor `m_L = 1 / ln(m)`.
    fn m_l(&self) -> f32 {
        (self.m as f32).ln().recip()
    }
}
46
47fn gen_level(options: &HnswBuilderOptions, rng: &mut impl Rng) -> usize {
48    let level = (-UniformFloat::<f32>::sample_single(0.0, 1.0, rng)
49        .unwrap()
50        .ln()
51        * options.m_l())
52    .floor() as usize;
53    min(level, options.max_level)
54}
55
56pub(crate) fn new_node(options: &HnswBuilderOptions, rng: &mut impl Rng) -> VectorHnswNode {
57    let level = gen_level(options, rng);
58    let mut level_neighbours = Vec::with_capacity(level);
59    level_neighbours.extend((0..=level).map(|level| BoundedNearest::new(options.level_m(level))));
60    VectorHnswNode { level_neighbours }
61}
62
63pub(crate) struct VectorHnswNode {
64    level_neighbours: Vec<BoundedNearest<usize>>,
65}
66
67impl VectorHnswNode {
68    fn level(&self) -> usize {
69        self.level_neighbours.len()
70    }
71}
72
73struct VectorStoreImpl {
74    vector_len: usize,
75    vector_payload: Vec<VectorItem>,
76    info_payload: Vec<u8>,
77    info_offsets: Vec<usize>,
78}
79
80impl VectorStoreImpl {
81    fn new(vector_len: usize) -> Self {
82        Self {
83            vector_len,
84            vector_payload: vec![],
85            info_payload: Default::default(),
86            info_offsets: vec![],
87        }
88    }
89
90    fn len(&self) -> usize {
91        self.info_offsets.len()
92    }
93
94    fn vec_ref(&self, idx: usize) -> VectorRef<'_> {
95        assert!(idx < self.info_offsets.len());
96        let start = idx * self.vector_len;
97        let end = start + self.vector_len;
98        VectorInner(&self.vector_payload[start..end])
99    }
100
101    fn info(&self, idx: usize) -> &[u8] {
102        let start = self.info_offsets[idx];
103        let end = if idx < self.info_offsets.len() - 1 {
104            self.info_offsets[idx + 1]
105        } else {
106            self.info_payload.len()
107        };
108        &self.info_payload[start..end]
109    }
110
111    fn add(&mut self, vec: VectorRef<'_>, info: &[u8]) {
112        assert_eq!(vec.0.len(), self.vector_len);
113
114        self.vector_payload.extend_from_slice(vec.0);
115        let offset = self.info_payload.len();
116        self.info_payload.extend_from_slice(info);
117        self.info_offsets.push(offset);
118    }
119}
120
121pub trait VectorAccessor {
122    fn vec_ref(&self) -> VectorRef<'_>;
123
124    fn info(&self) -> &[u8];
125}
126
127pub trait VectorStore: 'static {
128    type Accessor<'a>: VectorAccessor + 'a
129    where
130        Self: 'a;
131    async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>>;
132}
133
134pub struct VectorStoreImplAccessor<'a> {
135    vector_store_impl: &'a VectorStoreImpl,
136    idx: usize,
137}
138
139impl VectorAccessor for VectorStoreImplAccessor<'_> {
140    fn vec_ref(&self) -> VectorRef<'_> {
141        self.vector_store_impl.vec_ref(self.idx)
142    }
143
144    fn info(&self) -> &[u8] {
145        self.vector_store_impl.info(self.idx)
146    }
147}
148
149impl VectorStore for VectorStoreImpl {
150    type Accessor<'a> = VectorStoreImplAccessor<'a>;
151
152    async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>> {
153        Ok(VectorStoreImplAccessor {
154            vector_store_impl: self,
155            idx,
156        })
157    }
158}
159
160#[expect(clippy::len_without_is_empty)]
161pub trait HnswGraph {
162    fn entrypoint(&self) -> usize;
163    fn len(&self) -> usize;
164    fn node_level(&self, idx: usize) -> usize;
165    fn node_neighbours(
166        &self,
167        idx: usize,
168        level: usize,
169    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_;
170}
171
172pub struct HnswGraphBuilder {
173    /// entrypoint of the graph: Some(`entrypoint_vector_idx`)
174    entrypoint: usize,
175    nodes: Vec<VectorHnswNode>,
176}
177
178impl HnswGraphBuilder {
179    pub(crate) fn first(node: VectorHnswNode) -> Self {
180        Self {
181            entrypoint: 0,
182            nodes: vec![node],
183        }
184    }
185}
186
187impl HnswGraph for HnswGraphBuilder {
188    fn entrypoint(&self) -> usize {
189        self.entrypoint
190    }
191
192    fn len(&self) -> usize {
193        self.nodes.len()
194    }
195
196    fn node_level(&self, idx: usize) -> usize {
197        self.nodes[idx].level()
198    }
199
200    fn node_neighbours(
201        &self,
202        idx: usize,
203        level: usize,
204    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
205        (&self.nodes[idx].level_neighbours[level])
206            .into_iter()
207            .map(|(distance, &neighbour_index)| (neighbour_index, distance))
208    }
209}
210
211pub struct HnswBuilder<V: VectorStore, G: HnswGraph, M: MeasureDistanceBuilder, R: Rng> {
212    options: HnswBuilderOptions,
213
214    // payload
215    vector_store: V,
216    graph: Option<G>,
217
218    // utils
219    rng: R,
220    _measure: PhantomData<M>,
221}
222
/// Counters collected during one insertion or search, for diagnostics.
#[derive(Default, Debug)]
pub struct HnswStats {
    /// Number of distance evaluations performed.
    distances_computed: usize,
    /// Number of candidate nodes whose neighbour lists were expanded.
    nhops: usize,
}
228
/// Dense "visited" set over the indices `0..size`.
struct VecSet {
    // TODO: optimize with bitmap
    payload: Vec<bool>,
}

impl VecSet {
    /// Creates an empty set able to hold indices `0..size`.
    fn new(size: usize) -> Self {
        let payload = std::iter::repeat(false).take(size).collect();
        Self { payload }
    }

    /// Marks `idx` as present. Panics if `idx >= size`.
    fn set(&mut self, idx: usize) {
        self.payload[idx] = true;
    }

    /// Whether `idx` is present. Panics if `idx >= size`.
    fn is_set(&self, idx: usize) -> bool {
        self.payload[idx]
    }

    /// Removes every index, keeping the capacity.
    fn reset(&mut self) {
        self.payload.iter_mut().for_each(|flag| *flag = false);
    }
}
253
254impl<M: MeasureDistanceBuilder, R: Rng> HnswBuilder<VectorStoreImpl, HnswGraphBuilder, M, R> {
255    pub fn new(vector_len: usize, rng: R, options: HnswBuilderOptions) -> Self {
256        Self {
257            options,
258            graph: None,
259            vector_store: VectorStoreImpl::new(vector_len),
260            rng,
261            _measure: Default::default(),
262        }
263    }
264
265    pub fn with_faiss_hnsw(self, faiss_hnsw: Hnsw<'_>) -> Self {
266        assert_eq!(self.vector_store.len(), faiss_hnsw.levels_raw().len());
267        let (entry_point, _max_level) = faiss_hnsw.entry_point().unwrap();
268        let levels = faiss_hnsw.levels_raw();
269        let Some(graph) = &self.graph else {
270            assert_eq!(levels.len(), 0);
271            return Self::new(self.vector_store.vector_len, self.rng, self.options);
272        };
273        assert_eq!(levels.len(), graph.nodes.len());
274        let mut nodes = Vec::with_capacity(graph.nodes.len());
275        for (node, level_count) in levels.iter().enumerate() {
276            let level_count = *level_count as usize;
277            let mut level_neighbors = Vec::with_capacity(level_count);
278            for level in 0..level_count {
279                let neighbors = faiss_hnsw.neighbors_raw(node, level);
280                let mut nearest_neighbors = BoundedNearest::new(neighbors.len());
281                for &neighbor in neighbors {
282                    nearest_neighbors.insert(
283                        M::distance(
284                            self.vector_store.vec_ref(node),
285                            self.vector_store.vec_ref(neighbor as _),
286                        ),
287                        || neighbor as _,
288                    );
289                }
290                level_neighbors.push(nearest_neighbors);
291            }
292            nodes.push(VectorHnswNode {
293                level_neighbours: level_neighbors,
294            });
295        }
296        Self {
297            options: self.options,
298            graph: Some(HnswGraphBuilder {
299                entrypoint: entry_point,
300                nodes,
301            }),
302            vector_store: self.vector_store,
303            rng: self.rng,
304            _measure: Default::default(),
305        }
306    }
307
308    pub fn print_graph(&self) {
309        let Some(graph) = &self.graph else {
310            println!("empty graph");
311            return;
312        };
313        println!(
314            "entrypoint {} in level {}",
315            graph.entrypoint,
316            graph.nodes[graph.entrypoint].level()
317        );
318        for (i, node) in graph.nodes.iter().enumerate() {
319            println!("node {} has {} levels", i, node.level());
320            for level in 0..node.level() {
321                print!("level {}: ", level);
322                for (_, &neighbor) in &node.level_neighbours[level] {
323                    print!("{} ", neighbor);
324                }
325                println!()
326            }
327        }
328    }
329
330    pub async fn insert(&mut self, vec: VectorRef<'_>, info: &[u8]) -> HummockResult<HnswStats> {
331        let node = new_node(&self.options, &mut self.rng);
332        let stat = if let Some(graph) = &mut self.graph {
333            insert_graph::<M>(
334                &self.vector_store,
335                graph,
336                node,
337                vec,
338                self.options.ef_construction,
339            )
340            .await?
341        } else {
342            self.graph = Some(HnswGraphBuilder::first(node));
343            HnswStats::default()
344        };
345        self.vector_store.add(vec, info);
346        Ok(stat)
347    }
348}
349
350pub(crate) async fn insert_graph<M: MeasureDistanceBuilder>(
351    vector_store: &impl VectorStore,
352    graph: &mut HnswGraphBuilder,
353    mut node: VectorHnswNode,
354    vec: VectorRef<'_>,
355    ef_construction: usize,
356) -> HummockResult<HnswStats> {
357    {
358        let mut stats = HnswStats::default();
359        let entrypoint_index = graph.entrypoint();
360        let measure = M::new(vec);
361        let mut entrypoints = BoundedNearest::new(1);
362        entrypoints.insert(
363            measure.measure(vector_store.get_vector(entrypoint_index).await?.vec_ref()),
364            || (entrypoint_index, ()),
365        );
366        let mut visited = VecSet::new(graph.nodes.len());
367        let entrypoint_level = graph.nodes[entrypoint_index].level();
368        {
369            let mut curr_level = entrypoint_level;
370            while curr_level > node.level() + 1 {
371                curr_level -= 1;
372                entrypoints = search_layer(
373                    vector_store,
374                    &*graph,
375                    &measure,
376                    |_, _, _| (),
377                    entrypoints,
378                    curr_level,
379                    1,
380                    &mut stats,
381                    &mut visited,
382                )
383                .await?;
384            }
385        }
386        {
387            let mut curr_level = min(entrypoint_level, node.level());
388            while curr_level > 0 {
389                curr_level -= 1;
390                entrypoints = search_layer(
391                    vector_store,
392                    &*graph,
393                    &measure,
394                    |_, _, _| (),
395                    entrypoints,
396                    curr_level,
397                    ef_construction,
398                    &mut stats,
399                    &mut visited,
400                )
401                .await?;
402                let level_neighbour = &mut node.level_neighbours[curr_level];
403                for (neighbour_distance, &(neighbour_index, _)) in &entrypoints {
404                    level_neighbour.insert(neighbour_distance, || neighbour_index);
405                }
406            }
407        }
408        let vector_index = graph.nodes.len();
409        for (level_index, level) in node.level_neighbours.iter().enumerate() {
410            for (neighbour_distance, &neighbour_index) in level {
411                graph.nodes[neighbour_index].level_neighbours[level_index]
412                    .insert(neighbour_distance, || vector_index);
413            }
414        }
415        if graph.nodes[entrypoint_index].level() < node.level() {
416            graph.entrypoint = vector_index;
417        }
418        graph.nodes.push(node);
419        Ok(stats)
420    }
421}
422
423pub async fn nearest<O: Send, M: MeasureDistanceBuilder>(
424    vector_store: &impl VectorStore,
425    graph: &impl HnswGraph,
426    vec: VectorRef<'_>,
427    on_nearest_fn: impl OnNearestItem<O>,
428    ef_search: usize,
429    top_n: usize,
430) -> HummockResult<(Vec<O>, HnswStats)> {
431    {
432        let entrypoint_index = graph.entrypoint();
433        let measure = M::new(vec);
434        let mut entrypoints = BoundedNearest::new(1);
435        let mut stats = HnswStats::default();
436        let entrypoint_vector = vector_store.get_vector(entrypoint_index).await?;
437        let entrypoint_distance = measure.measure(entrypoint_vector.vec_ref());
438        entrypoints.insert(entrypoint_distance, || {
439            (
440                entrypoint_index,
441                on_nearest_fn(
442                    entrypoint_vector.vec_ref(),
443                    entrypoint_distance,
444                    entrypoint_vector.info(),
445                ),
446            )
447        });
448        stats.distances_computed += 1;
449        let entrypoint_level = graph.node_level(entrypoint_index);
450        let mut visited = VecSet::new(graph.len());
451        {
452            let mut curr_level = entrypoint_level;
453            while curr_level > 1 {
454                curr_level -= 1;
455                entrypoints = search_layer(
456                    vector_store,
457                    graph,
458                    &measure,
459                    &on_nearest_fn,
460                    entrypoints,
461                    curr_level,
462                    1,
463                    &mut stats,
464                    &mut visited,
465                )
466                .await?;
467            }
468        }
469        entrypoints = search_layer(
470            vector_store,
471            graph,
472            &measure,
473            &on_nearest_fn,
474            entrypoints,
475            0,
476            ef_search,
477            &mut stats,
478            &mut visited,
479        )
480        .await?;
481        Ok((
482            entrypoints.collect_with(|(_, output)| output, Some(top_n)),
483            stats,
484        ))
485    }
486}
487
488async fn search_layer<O: Send>(
489    vector_store: &impl VectorStore,
490    graph: &impl HnswGraph,
491    measure: &impl MeasureDistance,
492    on_nearest_fn: impl OnNearestItem<O>,
493    entrypoints: BoundedNearest<(usize, O)>,
494    level_index: usize,
495    ef: usize,
496    stats: &mut HnswStats,
497    visited: &mut VecSet,
498) -> HummockResult<BoundedNearest<(usize, O)>> {
499    {
500        visited.reset();
501        let mut candidates = MinDistanceHeap::with_capacity(ef);
502        for (distance, &(idx, _)) in &entrypoints {
503            visited.set(idx);
504            candidates.push(distance, idx);
505        }
506        let mut nearest = entrypoints;
507        nearest.resize(ef);
508
509        while let Some((c_distance, c_index)) = candidates.pop() {
510            let (f_distance, _) = nearest.furthest().expect("non-empty");
511            if c_distance > f_distance {
512                // early break here when even the nearest node in `candidates` is further than the
513                // furthest node in the `nearest` set, because no node in `candidates` can be added to `nearest`
514                break;
515            }
516            stats.nhops += 1;
517            for (neighbour_index, _) in graph.node_neighbours(c_index, level_index) {
518                if visited.is_set(neighbour_index) {
519                    continue;
520                }
521                visited.set(neighbour_index);
522                let vector = vector_store.get_vector(neighbour_index).await?;
523                let info = vector.info();
524                let distance = measure.measure(vector.vec_ref());
525                stats.distances_computed += 1;
526                let mut added = false;
527                let added = &mut added;
528                nearest.insert(distance, || {
529                    *added = true;
530                    (
531                        neighbour_index,
532                        on_nearest_fn(vector.vec_ref(), distance, info),
533                    )
534                });
535                if *added {
536                    candidates.push(distance, neighbour_index);
537                }
538            }
539        }
540
541        Ok(nearest)
542    }
543}
544
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::iter::repeat_with;
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use faiss::{ConcurrentIndex, Index, MetricType};
    use futures::executor::block_on;
    use itertools::Itertools;
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    use crate::vector::NearestBuilder;
    use crate::vector::distance::InnerProductDistance;
    use crate::vector::hnsw::{HnswBuilder, HnswBuilderOptions, nearest};
    use crate::vector::test_utils::{gen_info, gen_vector};

    /// Fraction of the `expected` infos that also appear in `actual`.
    fn recall(actual: &[Bytes], expected: &[Bytes]) -> f32 {
        let expected: HashSet<_> = expected.iter().map(|b| b.as_ref()).collect();
        (actual
            .iter()
            .filter(|info| expected.contains(info.as_ref()))
            .count() as f32)
            / (expected.len() as f32)
    }

    const VERBOSE: bool = false;
    const VECTOR_LEN: usize = 128;
    const INPUT_COUNT: usize = 20000;
    const QUERY_COUNT: usize = 5000;
    const TOP_N: usize = 10;
    const EF_SEARCH_LIST: &[usize] = &[16];
    // const EF_SEARCH_LIST: &'static [usize] = &[16, 30, 100];

    /// Builds the same data set with this HNSW and with faiss, then compares query latency and
    /// recall of both against a brute-force ground truth.
    #[tokio::test]
    async fn test_hnsw_basic() {
        let input = (0..INPUT_COUNT)
            .map(|i| (gen_vector(VECTOR_LEN), gen_info(i)))
            .collect_vec();
        let m = 40;
        let hnsw_start_time = Instant::now();
        let mut hnsw_builder = HnswBuilder::<_, _, InnerProductDistance, _>::new(
            VECTOR_LEN,
            StdRng::seed_from_u64(233),
            // StdRng::try_from_os_rng().unwrap(),
            HnswBuilderOptions {
                m,
                ef_construction: 40,
                max_level: 10,
            },
        );
        for (vec, info) in &input {
            hnsw_builder.insert(vec.to_ref(), info).await.unwrap();
        }
        println!("hnsw build time: {:?}", hnsw_start_time.elapsed());
        if VERBOSE {
            hnsw_builder.print_graph();
        }

        let faiss_hnsw_start_time = Instant::now();
        let mut faiss_hnsw = faiss::index::hnsw::HnswFlatIndex::new(
            VECTOR_LEN as _,
            m as _,
            MetricType::InnerProduct,
        )
        .unwrap();

        faiss_hnsw
            .add(&hnsw_builder.vector_store.vector_payload)
            .unwrap();
        // for (vec, info) in &input {
        //     faiss_hnsw.add(&vec.0).unwrap();
        // }

        if VERBOSE {
            let faiss_hnsw = faiss_hnsw.hnsw();
            let (entry_point, max_level) = faiss_hnsw.entry_point().unwrap();
            println!("faiss hnsw entry_point: {} {}", entry_point, max_level);
            let levels = faiss_hnsw.levels_raw();
            println!("entry point level: {}", levels[entry_point]);
            for level in 0..=max_level {
                let neighbors = faiss_hnsw.neighbors_raw(entry_point, level);
                println!("entry point level {} neighbors {:?}", level, neighbors);
            }
        }
        println!(
            "faiss hnsw build time: {:?}",
            faiss_hnsw_start_time.elapsed()
        );

        // let hnsw_builder = hnsw_builder.with_faiss_hnsw(faiss_hnsw.hnsw());

        let queries = (0..QUERY_COUNT)
            .map(|_| gen_vector(VECTOR_LEN))
            .collect_vec();
        let expected = queries
            .iter()
            .map(|query| {
                let mut nearest_builder =
                    NearestBuilder::<'_, _, InnerProductDistance>::new(query.to_ref(), TOP_N);
                nearest_builder.add(
                    input
                        .iter()
                        .map(|(vec, info)| (vec.to_ref(), info.as_ref())),
                    |_, _, info| Bytes::copy_from_slice(info),
                );
                nearest_builder.finish()
            })
            .collect_vec();
        let faiss_start_time = Instant::now();
        let repeat_query = if cfg!(debug_assertions) { 1 } else { 60 };
        println!("start faiss query");
        let faiss_actual = repeat_with(|| queries.iter().enumerate())
            .take(repeat_query)
            .flatten()
            .map(|(i, query)| {
                let start_time = Instant::now();
                let actual = faiss_hnsw
                    .assign(&query.0, TOP_N)
                    .unwrap()
                    .labels
                    .into_iter()
                    .filter_map(|label| label.get().map(|label| gen_info(label as _)))
                    .collect_vec();
                let recall = recall(&actual, &expected[i]);
                (start_time.elapsed(), recall)
            })
            .collect_vec();
        let faiss_query_time = faiss_start_time.elapsed();
        println!("start query");
        let actuals = EF_SEARCH_LIST
            .iter()
            .map(|&ef_search| {
                let start_time = Instant::now();
                let actuals = repeat_with(|| queries.iter().enumerate())
                    .take(repeat_query)
                    .flatten()
                    .map(|(i, query)| {
                        let start_time = Instant::now();
                        let (actual, stats) = block_on(nearest::<_, InnerProductDistance>(
                            &hnsw_builder.vector_store,
                            hnsw_builder.graph.as_ref().unwrap(),
                            query.to_ref(),
                            |_, _, info| Bytes::copy_from_slice(info),
                            ef_search,
                            TOP_N,
                        ))
                        .unwrap();
                        if VERBOSE {
                            println!("stats: {:?}", stats);
                        }
                        let recall = recall(&actual, &expected[i]);
                        (start_time.elapsed(), recall)
                    })
                    .collect_vec();
                (actuals, start_time.elapsed())
            })
            .collect_vec();
        if VERBOSE {
            for i in 0..20 {
                for elapsed in [&faiss_actual]
                    .into_iter()
                    .chain(actuals.iter().map(|(actual, _)| actual))
                    .map(|actual| actual[i].0)
                {
                    print!("{:?}\t", elapsed);
                }
                println!();
                for recall in [&faiss_actual]
                    .into_iter()
                    .chain(actuals.iter().map(|(actual, _)| actual))
                    .map(|actual| actual[i].1)
                {
                    print!("{}\t", recall);
                }
                println!();
            }
        }
        /// Mean recall over `(elapsed, recall)` samples.
        fn avg_recall(actual: &[(Duration, f32)]) -> f32 {
            actual.iter().map(|(_, recall)| *recall).sum::<f32>() / (actual.len() as f32)
        }
        println!("faiss {:?} {}", faiss_query_time, avg_recall(&faiss_actual));
        for (ef_search, (actual, elapsed)) in EF_SEARCH_LIST.iter().zip(&actuals) {
            println!(
                "ef_search[{}] {:?} {}",
                ef_search,
                elapsed,
                avg_recall(actual)
            );
        }
    }
}