risingwave_storage/vector/
hnsw.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::{max, min};
16use std::marker::PhantomData;
17
18use faiss::index::hnsw::Hnsw;
19use itertools::Itertools;
20use rand::Rng;
21use rand::distr::uniform::{UniformFloat, UniformSampler};
22use risingwave_pb::hummock::PbHnswGraph;
23use risingwave_pb::hummock::hnsw_graph::{PbHnswLevel, PbHnswNeighbor, PbHnswNode};
24
25use crate::vector::utils::{BoundedNearest, MinDistanceHeap};
26use crate::vector::{
27    MeasureDistance, MeasureDistanceBuilder, OnNearestItem, VectorDistance, VectorItem, VectorRef,
28};
29
/// Tunable parameters used while building an HNSW graph.
#[derive(Clone, Copy)]
pub struct HnswBuilderOptions {
    /// Base number of neighbours kept per node on each level; see [`level_m`]
    /// for the per-level capacity derived from it.
    pub m: usize,
    /// Search breadth (`ef`) passed to the layer search while connecting a
    /// newly inserted node.
    pub ef_construction: usize,
    /// Upper bound applied to the randomly generated top level index of a node.
    pub max_level: usize,
}
36
/// Returns the neighbour capacity for a node on `level`.
///
/// Following pg_vector, the ground level (`level == 0`) keeps twice as many
/// connections as the upper levels, which keep exactly `m`.
pub fn level_m(m: usize, level: usize) -> usize {
    match level {
        0 => 2 * m,
        _ => m,
    }
}
42
43impl HnswBuilderOptions {
44    fn m_l(&self) -> f32 {
45        1.0 / (self.m as f32).ln()
46    }
47}
48
49fn gen_level(options: &HnswBuilderOptions, rng: &mut impl Rng) -> usize {
50    let level = (-UniformFloat::<f32>::sample_single(0.0, 1.0, rng)
51        .unwrap()
52        .ln()
53        * options.m_l())
54    .floor() as usize;
55    min(level, options.max_level)
56}
57
58pub(crate) fn new_node(options: &HnswBuilderOptions, rng: &mut impl Rng) -> VectorHnswNode {
59    let level = gen_level(options, rng);
60    let mut level_neighbours = Vec::with_capacity(level + 1);
61    level_neighbours
62        .extend((0..=level).map(|level| BoundedNearest::new(level_m(options.m, level))));
63    VectorHnswNode { level_neighbours }
64}
65
66pub(crate) struct VectorHnswNode {
67    level_neighbours: Vec<BoundedNearest<usize>>,
68}
69
70impl VectorHnswNode {
71    /// Returns the number of levels this node has (levels are indexed 0..num_levels-1).
72    fn num_levels(&self) -> usize {
73        self.level_neighbours.len()
74    }
75}
76
77struct InMemoryVectorStore {
78    dimension: usize,
79    vector_payload: Vec<VectorItem>,
80    info_payload: Vec<u8>,
81    info_offsets: Vec<usize>,
82}
83
84impl InMemoryVectorStore {
85    fn new(dimension: usize) -> Self {
86        Self {
87            dimension,
88            vector_payload: vec![],
89            info_payload: Default::default(),
90            info_offsets: vec![],
91        }
92    }
93
94    fn len(&self) -> usize {
95        self.info_offsets.len()
96    }
97
98    fn vec_ref(&self, idx: usize) -> VectorRef<'_> {
99        assert!(idx < self.info_offsets.len());
100        let start = idx * self.dimension;
101        let end = start + self.dimension;
102        VectorRef::from_slice_unchecked(&self.vector_payload[start..end])
103    }
104
105    fn info(&self, idx: usize) -> &[u8] {
106        let start = self.info_offsets[idx];
107        let end = if idx < self.info_offsets.len() - 1 {
108            self.info_offsets[idx + 1]
109        } else {
110            self.info_payload.len()
111        };
112        &self.info_payload[start..end]
113    }
114
115    fn add(&mut self, vec: VectorRef<'_>, info: &[u8]) {
116        assert_eq!(vec.dimension(), self.dimension);
117
118        self.vector_payload.extend_from_slice(vec.as_slice());
119        let offset = self.info_payload.len();
120        self.info_payload.extend_from_slice(info);
121        self.info_offsets.push(offset);
122    }
123}
124
125pub trait VectorAccessor {
126    fn vec_ref(&self) -> VectorRef<'_>;
127
128    fn info(&self) -> &[u8];
129}
130
131pub trait VectorStore: 'static {
132    type Accessor<'a>: VectorAccessor + 'a
133    where
134        Self: 'a;
135    type Ctx: 'static;
136    type Err;
137    async fn get_vector(
138        &self,
139        idx: usize,
140        ctx: &mut Self::Ctx,
141    ) -> Result<Self::Accessor<'_>, Self::Err>;
142}
143
144pub struct InMemoryVectorStoreAccessor<'a> {
145    vector_store_impl: &'a InMemoryVectorStore,
146    idx: usize,
147}
148
149impl VectorAccessor for InMemoryVectorStoreAccessor<'_> {
150    fn vec_ref(&self) -> VectorRef<'_> {
151        self.vector_store_impl.vec_ref(self.idx)
152    }
153
154    fn info(&self) -> &[u8] {
155        self.vector_store_impl.info(self.idx)
156    }
157}
158
159impl VectorStore for InMemoryVectorStore {
160    type Accessor<'a> = InMemoryVectorStoreAccessor<'a>;
161    type Ctx = ();
162    type Err = !;
163
164    async fn get_vector(&self, idx: usize, _: &mut ()) -> Result<Self::Accessor<'_>, !> {
165        Ok(InMemoryVectorStoreAccessor {
166            vector_store_impl: self,
167            idx,
168        })
169    }
170}
171
172#[expect(clippy::len_without_is_empty)]
173pub trait HnswGraph {
174    fn entrypoint(&self) -> usize;
175    fn len(&self) -> usize;
176    /// Returns the number of levels for `idx` (levels are indexed 0..num_levels-1).
177    fn node_num_levels(&self, idx: usize) -> usize;
178    fn node_neighbours(
179        &self,
180        idx: usize,
181        level: usize,
182    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_;
183}
184
185impl HnswGraph for PbHnswGraph {
186    fn entrypoint(&self) -> usize {
187        self.entrypoint_id as _
188    }
189
190    fn len(&self) -> usize {
191        self.nodes.len()
192    }
193
194    fn node_num_levels(&self, idx: usize) -> usize {
195        self.nodes[idx].levels.len()
196    }
197
198    fn node_neighbours(
199        &self,
200        idx: usize,
201        level: usize,
202    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
203        self.nodes[idx]
204            .levels
205            .get(level)
206            .into_iter()
207            .flat_map(|level| {
208                level
209                    .neighbors
210                    .iter()
211                    .map(|neighbor| (neighbor.vector_id as usize, neighbor.distance))
212            })
213    }
214}
215
216pub struct HnswGraphBuilder {
217    /// entrypoint of the graph: Some(`entrypoint_vector_idx`)
218    entrypoint: usize,
219    nodes: Vec<VectorHnswNode>,
220    level_node_count: Vec<usize>,
221}
222
223fn recovery_level_count(entrypoint: usize, nodes: &[VectorHnswNode]) -> Vec<usize> {
224    let mut level_count = vec![0; nodes[entrypoint].num_levels()];
225    for node in nodes {
226        level_count[node.num_levels() - 1] += 1;
227    }
228    level_count
229}
230
231impl HnswGraphBuilder {
232    pub(crate) fn first(node: VectorHnswNode) -> Self {
233        let nodes = vec![node];
234        let level_count = recovery_level_count(0, &nodes);
235        Self {
236            entrypoint: 0,
237            nodes,
238            level_node_count: level_count,
239        }
240    }
241
242    pub fn to_protobuf(&self) -> PbHnswGraph {
243        let mut nodes = Vec::with_capacity(self.nodes.len());
244        for node in &self.nodes {
245            let mut levels = Vec::with_capacity(node.num_levels());
246            for level in &node.level_neighbours {
247                let mut neighbors = Vec::with_capacity(level.len());
248                for (distance, &neighbor_index) in level {
249                    neighbors.push(PbHnswNeighbor {
250                        vector_id: neighbor_index as u64,
251                        distance,
252                    });
253                }
254                levels.push(PbHnswLevel { neighbors });
255            }
256            nodes.push(PbHnswNode { levels });
257        }
258        PbHnswGraph {
259            entrypoint_id: self.entrypoint as u64,
260            nodes,
261        }
262    }
263
264    pub fn from_protobuf(pb: &PbHnswGraph, m: usize) -> Self {
265        let entrypoint = pb.entrypoint_id as usize;
266        let nodes = pb
267            .nodes
268            .iter()
269            .map(|node| {
270                let level_neighbours = node
271                    .levels
272                    .iter()
273                    .enumerate()
274                    .map(|(level_idx, level)| {
275                        let level_m = level_m(m, level_idx);
276                        let mut nearest = BoundedNearest::new(level_m);
277                        for neighbor in &level.neighbors {
278                            nearest.insert(neighbor.distance, || neighbor.vector_id as _);
279                        }
280                        nearest
281                    })
282                    .collect();
283                VectorHnswNode { level_neighbours }
284            })
285            .collect_vec();
286        let level_count = recovery_level_count(entrypoint, &nodes);
287        Self {
288            entrypoint,
289            nodes,
290            level_node_count: level_count,
291        }
292    }
293
294    pub fn level_node_count(&self) -> &[usize] {
295        &self.level_node_count
296    }
297}
298
299impl HnswGraph for HnswGraphBuilder {
300    fn entrypoint(&self) -> usize {
301        self.entrypoint
302    }
303
304    fn len(&self) -> usize {
305        self.nodes.len()
306    }
307
308    fn node_num_levels(&self, idx: usize) -> usize {
309        self.nodes[idx].num_levels()
310    }
311
312    fn node_neighbours(
313        &self,
314        idx: usize,
315        level: usize,
316    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
317        (&self.nodes[idx].level_neighbours[level])
318            .into_iter()
319            .map(|(distance, &neighbour_index)| (neighbour_index, distance))
320    }
321}
322
323pub struct HnswBuilder<V: VectorStore, G: HnswGraph, M: MeasureDistanceBuilder, R: Rng> {
324    options: HnswBuilderOptions,
325
326    // payload
327    vector_store: V,
328    ctx: V::Ctx,
329    graph: Option<G>,
330
331    // utils
332    rng: R,
333    _measure: PhantomData<M>,
334}
335
/// Counters collected during a single insert or search.
#[derive(Debug, Default)]
pub struct HnswStats {
    /// Number of distance evaluations performed.
    pub distances_computed: usize,
    /// Number of candidate nodes whose neighbour lists were expanded.
    pub n_hops: usize,
}
341
/// Fixed-size set of indices, backed by one flag per possible index.
struct VecSet {
    // TODO: optimize with bitmap
    payload: Vec<bool>,
}

impl VecSet {
    /// Creates an empty set able to hold the indices `0..size`.
    fn new(size: usize) -> Self {
        let payload = vec![false; size];
        Self { payload }
    }

    /// Marks `idx` as a member. Panics if `idx` is out of range.
    fn set(&mut self, idx: usize) {
        let slot = &mut self.payload[idx];
        *slot = true;
    }

    /// Reports whether `idx` is a member. Panics if `idx` is out of range.
    fn is_set(&self, idx: usize) -> bool {
        self.payload[idx]
    }

    /// Removes every member while keeping the capacity.
    fn reset(&mut self) {
        self.payload.iter_mut().for_each(|flag| *flag = false);
    }
}
366
367impl<M: MeasureDistanceBuilder, R: Rng> HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, M, R> {
368    pub fn new(dimension: usize, rng: R, options: HnswBuilderOptions) -> Self {
369        Self {
370            options,
371            graph: None,
372            vector_store: InMemoryVectorStore::new(dimension),
373            ctx: (),
374            rng,
375            _measure: Default::default(),
376        }
377    }
378
379    pub fn with_faiss_hnsw(self, faiss_hnsw: Hnsw<'_>) -> Self {
380        assert_eq!(self.vector_store.len(), faiss_hnsw.levels_raw().len());
381        let (entry_point, _max_level) = faiss_hnsw.entry_point().unwrap();
382        let levels = faiss_hnsw.levels_raw();
383        let Some(graph) = &self.graph else {
384            assert_eq!(levels.len(), 0);
385            return Self::new(self.vector_store.dimension, self.rng, self.options);
386        };
387        assert_eq!(levels.len(), graph.nodes.len());
388        let mut nodes = Vec::with_capacity(graph.nodes.len());
389        for (node, level_count) in levels.iter().enumerate() {
390            let level_count = *level_count as usize;
391            let mut level_neighbors = Vec::with_capacity(level_count);
392            for level_idx in 0..level_count {
393                let neighbors = faiss_hnsw.neighbors_raw(node, level_idx);
394                let mut nearest_neighbors = BoundedNearest::new(max(neighbors.len(), 1));
395                for &neighbor in neighbors {
396                    nearest_neighbors.insert(
397                        M::distance(
398                            self.vector_store.vec_ref(node),
399                            self.vector_store.vec_ref(neighbor as _),
400                        ),
401                        || neighbor as _,
402                    );
403                }
404                level_neighbors.push(nearest_neighbors);
405            }
406            nodes.push(VectorHnswNode {
407                level_neighbours: level_neighbors,
408            });
409        }
410        let level_count = recovery_level_count(entry_point, &nodes);
411        Self {
412            options: self.options,
413            graph: Some(HnswGraphBuilder {
414                entrypoint: entry_point,
415                nodes,
416                level_node_count: level_count,
417            }),
418            vector_store: self.vector_store,
419            ctx: (),
420            rng: self.rng,
421            _measure: Default::default(),
422        }
423    }
424
425    pub fn print_graph(&self) {
426        let Some(graph) = &self.graph else {
427            println!("empty graph");
428            return;
429        };
430        println!(
431            "entrypoint {} has {} levels",
432            graph.entrypoint,
433            graph.nodes[graph.entrypoint].num_levels()
434        );
435        for (i, node) in graph.nodes.iter().enumerate() {
436            println!("node {} has {} levels", i, node.num_levels());
437            for level in 0..node.num_levels() {
438                print!("level {}: ", level);
439                for (_, &neighbor) in &node.level_neighbours[level] {
440                    print!("{} ", neighbor);
441                }
442                println!()
443            }
444        }
445    }
446
447    pub async fn insert(&mut self, vec: VectorRef<'_>, info: &[u8]) -> HnswStats {
448        let node = new_node(&self.options, &mut self.rng);
449        let stat = if let Some(graph) = &mut self.graph {
450            let Ok(stat) = insert_graph::<M, _>(
451                &self.vector_store,
452                &mut self.ctx,
453                graph,
454                node,
455                vec,
456                self.options.ef_construction,
457            )
458            .await;
459            stat
460        } else {
461            self.graph = Some(HnswGraphBuilder::first(node));
462            HnswStats::default()
463        };
464        self.vector_store.add(vec, info);
465        stat
466    }
467}
468
469pub(crate) async fn insert_graph<M: MeasureDistanceBuilder, S: VectorStore>(
470    vector_store: &S,
471    ctx: &mut S::Ctx,
472    graph: &mut HnswGraphBuilder,
473    mut node: VectorHnswNode,
474    vec: VectorRef<'_>,
475    ef_construction: usize,
476) -> Result<HnswStats, S::Err> {
477    {
478        let mut stats = HnswStats::default();
479        let mut visited = VecSet::new(graph.nodes.len());
480        let entrypoint_index = graph.entrypoint();
481        let measure = M::new(vec);
482        let mut entrypoints = BoundedNearest::new(1);
483
484        entrypoints.insert(
485            measure.measure(
486                vector_store
487                    .get_vector(entrypoint_index, ctx)
488                    .await?
489                    .vec_ref(),
490            ),
491            || (entrypoint_index, ()),
492        );
493        stats.distances_computed += 1;
494
495        // Walk from entrypoint's top level down to (node_top + 1), inclusive.
496        let entry_top_level_idx = graph.nodes[entrypoint_index].num_levels() - 1;
497        let node_top_level_idx = node.num_levels() - 1;
498
499        for level_idx in ((node_top_level_idx + 1)..=entry_top_level_idx).rev() {
500            entrypoints = search_layer(
501                vector_store,
502                ctx,
503                &*graph,
504                &measure,
505                |_, _, _| (),
506                entrypoints,
507                level_idx,
508                1,
509                &mut stats,
510                &mut visited,
511            )
512            .await?;
513        }
514
515        // Connect from min(entry_top, node_top) down to ground (0).
516        let start_level_idx = min(entry_top_level_idx, node_top_level_idx);
517        for level_idx in (0..=start_level_idx).rev() {
518            entrypoints = search_layer(
519                vector_store,
520                ctx,
521                &*graph,
522                &measure,
523                |_, _, _| (),
524                entrypoints,
525                level_idx,
526                ef_construction,
527                &mut stats,
528                &mut visited,
529            )
530            .await?;
531            let level_neighbour = &mut node.level_neighbours[level_idx];
532            for (neighbour_distance, &(neighbour_index, _)) in &entrypoints {
533                level_neighbour.insert(neighbour_distance, || neighbour_index);
534            }
535        }
536
537        let vector_index = graph.nodes.len();
538        for (level_index, level) in node.level_neighbours.iter().enumerate() {
539            for (neighbour_distance, &neighbour_index) in level {
540                graph.nodes[neighbour_index].level_neighbours[level_index]
541                    .insert(neighbour_distance, || vector_index);
542            }
543        }
544        if graph.nodes[entrypoint_index].num_levels() < node.num_levels() {
545            graph.entrypoint = vector_index;
546            graph.level_node_count.resize(node.num_levels(), 0);
547        }
548        graph.level_node_count[node.num_levels() - 1] += 1;
549        graph.nodes.push(node);
550        Ok(stats)
551    }
552}
553
554pub async fn nearest<O: Send, M: MeasureDistanceBuilder, S: VectorStore>(
555    vector_store: &S,
556    ctx: &mut S::Ctx,
557    graph: &impl HnswGraph,
558    vec: VectorRef<'_>,
559    on_nearest_fn: impl OnNearestItem<O>,
560    ef_search: usize,
561    top_n: usize,
562) -> Result<(Vec<O>, HnswStats), S::Err> {
563    {
564        // Fast path: if no exploration breadth or no results requested, do nothing.
565        // Returns empty results and zeroed stats.
566        let mut stats = HnswStats::default();
567        if ef_search == 0 || top_n == 0 {
568            return Ok((Vec::new(), stats));
569        }
570
571        let entrypoint_index = graph.entrypoint();
572        let measure = M::new(vec);
573        let mut entrypoints = BoundedNearest::new(1);
574        {
575            let entrypoint_vector = vector_store.get_vector(entrypoint_index, ctx).await?;
576            let entrypoint_distance = measure.measure(entrypoint_vector.vec_ref());
577            entrypoints.insert(entrypoint_distance, || {
578                (
579                    entrypoint_index,
580                    on_nearest_fn(
581                        entrypoint_vector.vec_ref(),
582                        entrypoint_distance,
583                        entrypoint_vector.info(),
584                    ),
585                )
586            });
587        }
588        stats.distances_computed += 1;
589        let entry_top_level_idx = graph.node_num_levels(entrypoint_index) - 1;
590        let mut visited = VecSet::new(graph.len());
591        for level_idx in (1..=entry_top_level_idx).rev() {
592            entrypoints = search_layer(
593                vector_store,
594                ctx,
595                graph,
596                &measure,
597                &on_nearest_fn,
598                entrypoints,
599                level_idx, // level index
600                1,
601                &mut stats,
602                &mut visited,
603            )
604            .await?;
605        }
606        entrypoints = search_layer(
607            vector_store,
608            ctx,
609            graph,
610            &measure,
611            &on_nearest_fn,
612            entrypoints,
613            0,
614            ef_search,
615            &mut stats,
616            &mut visited,
617        )
618        .await?;
619        Ok((
620            entrypoints.collect_with(|_, (_, output)| output, Some(top_n)),
621            stats,
622        ))
623    }
624}
625
626async fn search_layer<O: Send, S: VectorStore>(
627    vector_store: &S,
628    ctx: &mut S::Ctx,
629    graph: &impl HnswGraph,
630    measure: &impl MeasureDistance,
631    on_nearest_fn: impl OnNearestItem<O>,
632    entrypoints: BoundedNearest<(usize, O)>,
633    level_index: usize,
634    ef: usize,
635    stats: &mut HnswStats,
636    visited: &mut VecSet,
637) -> Result<BoundedNearest<(usize, O)>, S::Err> {
638    #[cfg(test)]
639    {
640        __hnsw_test_hooks::record_level(level_index);
641    }
642    {
643        visited.reset();
644
645        let mut candidates = MinDistanceHeap::with_capacity(ef);
646        for (distance, &(idx, _)) in &entrypoints {
647            visited.set(idx);
648            candidates.push(distance, idx);
649        }
650        let mut nearest = entrypoints;
651        nearest.resize(ef);
652
653        while let Some((c_distance, c_index)) = candidates.pop() {
654            let (f_distance, _) = nearest.furthest().expect("non-empty");
655            if c_distance > f_distance {
656                // early break here when even the nearest node in `candidates` is further than the
657                // furthest node in the `nearest` set, because no node in `candidates` can be added to `nearest`
658                break;
659            }
660            stats.n_hops += 1;
661            for (neighbour_index, _) in graph.node_neighbours(c_index, level_index) {
662                if visited.is_set(neighbour_index) {
663                    continue;
664                }
665                visited.set(neighbour_index);
666                let vector = vector_store.get_vector(neighbour_index, ctx).await?;
667                let info = vector.info();
668
669                let distance = measure.measure(vector.vec_ref());
670
671                stats.distances_computed += 1;
672
673                let mut added = false;
674                let added = &mut added;
675                nearest.insert(distance, || {
676                    *added = true;
677                    (
678                        neighbour_index,
679                        on_nearest_fn(vector.vec_ref(), distance, info),
680                    )
681                });
682                if *added {
683                    candidates.push(distance, neighbour_index);
684                }
685            }
686        }
687
688        Ok(nearest)
689    }
690}
691
/// Test-only hooks that record, per thread, which levels were searched.
#[cfg(test)]
mod __hnsw_test_hooks {
    use std::cell::RefCell;

    thread_local! {
        #[allow(clippy::missing_const_for_thread_local)]
        static LEVELS: RefCell<Vec<usize>> = RefCell::new(Vec::new());
    }

    /// Appends `level` to the current thread's record.
    pub fn record_level(level: usize) {
        LEVELS.with_borrow_mut(|levels| levels.push(level));
    }

    /// Returns the recorded levels in recording order and leaves the record empty.
    pub fn take_levels() -> Vec<usize> {
        LEVELS.take()
    }

    /// Discards everything recorded so far on the current thread.
    pub fn clear_levels() {
        LEVELS.with_borrow_mut(Vec::clear);
    }
}
713
714#[cfg(test)]
715mod tests {
716    use std::collections::HashSet;
717    use std::iter::repeat_with;
718    use std::time::{Duration, Instant};
719
720    use bytes::Bytes;
721    use faiss::{ConcurrentIndex, Index, MetricType};
722    use futures::executor::block_on;
723    use itertools::Itertools;
724    use rand::SeedableRng;
725    use rand::prelude::StdRng;
726    use risingwave_common::types::F32;
727    use risingwave_common::util::iter_util::ZipEqDebug;
728    use risingwave_common::vector::distance::InnerProductDistance;
729
730    use super::*;
731    use crate::vector::NearestBuilder;
732    use crate::vector::hnsw::{HnswBuilder, HnswBuilderOptions, nearest};
733    use crate::vector::test_utils::{gen_info, gen_vector};
734
735    fn recall(actual: &Vec<Bytes>, expected: &Vec<Bytes>) -> f32 {
736        let expected: HashSet<_> = expected.iter().map(|b| b.as_ref()).collect();
737        (actual
738            .iter()
739            .filter(|info| expected.contains(info.as_ref()))
740            .count() as f32)
741            / (expected.len() as f32)
742    }
743
744    /// Minimal L2 distance for tests using only public traits.
745    struct TestL2;
746
747    struct TestL2Measure<'a> {
748        q: VectorRef<'a>,
749    }
750
751    impl MeasureDistanceBuilder for TestL2 {
752        type Measure<'a> = TestL2Measure<'a>;
753
754        fn new<'a>(q: VectorRef<'a>) -> Self::Measure<'a> {
755            TestL2Measure { q }
756        }
757
758        fn distance(a: VectorRef<'_>, b: VectorRef<'_>) -> VectorDistance {
759            // Sum of squared diffs over the public slice API.
760            a.as_slice()
761                .iter()
762                .zip_eq_debug(b.as_slice().iter())
763                .map(|(&x, &y)| {
764                    // VectorItem <-> f32 conversion (mirrors test_utils usage).
765                    let xf: f32 = x.into();
766                    let yf: f32 = y.into();
767                    let d = (xf as f64) - (yf as f64);
768                    d * d
769                })
770                .sum::<f64>()
771        }
772    }
773
774    impl<'a> MeasureDistance for TestL2Measure<'a> {
775        fn measure(&self, v: VectorRef<'_>) -> VectorDistance {
776            TestL2::distance(self.q, v)
777        }
778    }
779
780    fn opts(m: usize, efc: usize, max_level: usize) -> HnswBuilderOptions {
781        HnswBuilderOptions {
782            m,
783            ef_construction: efc,
784            max_level,
785        }
786    }
787
788    const VERBOSE: bool = false;
789    const VECTOR_LEN: usize = 128;
790    const INPUT_COUNT: usize = 20000;
791    const QUERY_COUNT: usize = 5000;
792    const TOP_N: usize = 10;
793    const EF_SEARCH_LIST: &[usize] = &[16];
794    // const EF_SEARCH_LIST: &'static [usize] = &[16, 30, 100];
795
796    #[cfg(not(madsim))]
797    #[tokio::test]
798    async fn test_hnsw_basic() {
799        let input = (0..INPUT_COUNT)
800            .map(|i| (gen_vector(VECTOR_LEN), gen_info(i)))
801            .collect_vec();
802        let m = 40;
803        let hnsw_start_time = Instant::now();
804        let mut hnsw_builder = HnswBuilder::<_, _, InnerProductDistance, _>::new(
805            VECTOR_LEN,
806            StdRng::seed_from_u64(233),
807            // StdRng::try_from_os_rng().unwrap(),
808            HnswBuilderOptions {
809                m,
810                ef_construction: 40,
811                max_level: 10,
812            },
813        );
814        for (vec, info) in &input {
815            hnsw_builder.insert(vec.to_ref(), info).await;
816        }
817        println!("hnsw build time: {:?}", hnsw_start_time.elapsed());
818        if VERBOSE {
819            hnsw_builder.print_graph();
820        }
821
822        let faiss_hnsw_start_time = Instant::now();
823        let mut faiss_hnsw = faiss::index::hnsw::HnswFlatIndex::new(
824            VECTOR_LEN as _,
825            m as _,
826            MetricType::InnerProduct,
827        )
828        .unwrap();
829
830        faiss_hnsw
831            .add(F32::inner_slice(&hnsw_builder.vector_store.vector_payload))
832            .unwrap();
833        // for (vec, info) in &input {
834        //     faiss_hnsw.add(&vec.0).unwrap();
835        // }
836
837        if VERBOSE {
838            let faiss_hnsw = faiss_hnsw.hnsw();
839            let (entry_point, max_level) = faiss_hnsw.entry_point().unwrap();
840            println!("faiss hnsw entry_point: {} {}", entry_point, max_level);
841            let levels = faiss_hnsw.levels_raw();
842            println!("entry point level: {}", levels[entry_point]);
843            for level in 0..=max_level {
844                let neighbors = faiss_hnsw.neighbors_raw(entry_point, level);
845                println!("entry point level {} neighbors {:?}", level, neighbors);
846            }
847        }
848        println!(
849            "faiss hnsw build time: {:?}",
850            faiss_hnsw_start_time.elapsed()
851        );
852
853        // let hnsw_builder = hnsw_builder.with_faiss_hnsw(faiss_hnsw.hnsw());
854
855        let queries = (0..QUERY_COUNT)
856            .map(|_| gen_vector(VECTOR_LEN))
857            .collect_vec();
858        let expected = queries
859            .iter()
860            .map(|query| {
861                let mut nearest_builder =
862                    NearestBuilder::<'_, _, InnerProductDistance>::new(query.to_ref(), TOP_N);
863                nearest_builder.add(
864                    input
865                        .iter()
866                        .map(|(vec, info)| (vec.to_ref(), info.as_ref())),
867                    |_, _, info| Bytes::copy_from_slice(info),
868                );
869                nearest_builder.finish()
870            })
871            .collect_vec();
872        let faiss_start_time = Instant::now();
873        let repeat_query = if cfg!(debug_assertions) { 1 } else { 60 };
874        println!("start faiss query");
875        let faiss_actual = repeat_with(|| queries.iter().enumerate())
876            .take(repeat_query)
877            .flatten()
878            .map(|(i, query)| {
879                let start_time = Instant::now();
880                let actual = faiss_hnsw
881                    .assign(query.as_raw_slice(), TOP_N)
882                    .unwrap()
883                    .labels
884                    .into_iter()
885                    .filter_map(|i| i.get().map(|i| gen_info(i as _)))
886                    .collect_vec();
887                let recall = recall(&actual, &expected[i]);
888                (start_time.elapsed(), recall)
889            })
890            .collect_vec();
891        let faiss_query_time = faiss_start_time.elapsed();
892        println!("start query");
893        let actuals = EF_SEARCH_LIST
894            .iter()
895            .map(|&ef_search| {
896                let start_time = Instant::now();
897                let actuals = repeat_with(|| queries.iter().enumerate())
898                    .take(repeat_query)
899                    .flatten()
900                    .map(|(i, query)| {
901                        let start_time = Instant::now();
902                        let (actual, stats) = block_on(nearest::<_, InnerProductDistance, _>(
903                            &hnsw_builder.vector_store,
904                            &mut (),
905                            hnsw_builder.graph.as_ref().unwrap(),
906                            query.to_ref(),
907                            |_, _, info| Bytes::copy_from_slice(info),
908                            ef_search,
909                            TOP_N,
910                        ))
911                        .unwrap();
912                        if VERBOSE {
913                            println!("stats: {:?}", stats);
914                        }
915                        let recall = recall(&actual, &expected[i]);
916                        (start_time.elapsed(), recall)
917                    })
918                    .collect_vec();
919                (actuals, start_time.elapsed())
920            })
921            .collect_vec();
922        if VERBOSE {
923            for i in 0..20 {
924                for elapsed in [&faiss_actual]
925                    .into_iter()
926                    .chain(actuals.iter().map(|(actual, _)| actual))
927                    .map(|actual| actual[i].0)
928                {
929                    print!("{:?}\t", elapsed);
930                }
931                println!();
932                for recall in [&faiss_actual]
933                    .into_iter()
934                    .chain(actuals.iter().map(|(actual, _)| actual))
935                    .map(|actual| actual[i].1)
936                {
937                    print!("{}\t", recall);
938                }
939                println!();
940            }
941        }
942        fn avg_recall(actual: &Vec<(Duration, f32)>) -> f32 {
943            actual.iter().map(|(_, elapsed)| *elapsed).sum::<f32>() / (actual.len() as f32)
944        }
945        println!("faiss {:?} {}", faiss_query_time, avg_recall(&faiss_actual));
946        for i in 0..EF_SEARCH_LIST.len() {
947            println!(
948                "ef_search[{}] {:?} {}",
949                EF_SEARCH_LIST[i],
950                actuals[i].1,
951                avg_recall(&actuals[i].0)
952            );
953        }
954    }
955
956    // Visits in insert_graph upper-layer descent should be: entry_top, entry_top-1, ..., node_top+1 (inclusive).
957    #[cfg(not(madsim))]
958    #[tokio::test]
959    async fn hnsw_insert_graph_visits_expected_upper_layers() {
960        use rand::SeedableRng;
961        use rand::rngs::StdRng;
962
963        use super::__hnsw_test_hooks as hooks;
964
965        // Use the same options helper from this test module.
966        let dim = 8;
967        let options = opts(8, 16, 8); // m, ef_construction, max_level
968
969        // Try a handful of seeds so we reliably get an entry node with >= 3 levels.
970        // This remains deterministic across runs.
971        for seed in 1u64..=200 {
972            // Fresh builder per attempt so the RNG state matches expectations.
973            let mut hnsw: HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, TestL2, StdRng> =
974                HnswBuilder::new(dim, StdRng::seed_from_u64(seed), options);
975
976            // Insert first vector: becomes the entrypoint.
977            let v0 = gen_vector(dim);
978            let _ = hnsw
979                .insert(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0))
980                .await;
981
982            // Peek current entrypoint's top level index.
983            let g = hnsw.graph.as_ref().unwrap();
984            let entry_idx = g.entrypoint();
985            let entry_top_level_idx = g.node_num_levels(entry_idx) - 1;
986
987            // We need at least 2 upper layers to make the assertion interesting.
988            if entry_top_level_idx < 2 {
989                continue; // try next seed
990            }
991
992            // Insert second vector and record which levels search_layer visits.
993            hooks::clear_levels();
994            let v1 = gen_vector(dim);
995            let _ = hnsw
996                .insert(VectorRef::from_slice_unchecked(v1.as_slice()), &gen_info(1))
997                .await;
998
999            // After insertion, read the new node's level count (it’s at the tail).
1000            let g = hnsw.graph.as_ref().unwrap();
1001            let new_idx = g.len() - 1;
1002            let node_top_level_idx = g.node_num_levels(new_idx) - 1;
1003
1004            // If the new node's top level > entry_top, HNSW would promote it to entrypoint.
1005            // That changes the descent semantics; skip such seeds to keep the assertion crisp.
1006            if node_top_level_idx > entry_top_level_idx {
1007                continue;
1008            }
1009
1010            // What the algorithm *should* have visited on the first descent:
1011            let expected: Vec<usize> = ((node_top_level_idx + 1)..=entry_top_level_idx)
1012                .rev()
1013                .collect();
1014
1015            // Extract the actual visited levels (recorded at the start of search_layer).
1016            let visited = hooks::take_levels();
1017            // Keep only *upper-layer* calls (>= 1); the final ground pass is level 0.
1018            let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();
1019
1020            assert_eq!(
1021                upper, expected,
1022                "seed={seed}, entry_top_level_idx={entry_top_level_idx}, node_top_level_idx={node_top_level_idx}"
1023            );
1024            return; // success on this seed
1025        }
1026
1027        panic!(
1028            "could not find a suitable seed (entry_top_level_idx>=2 and node_top<=entry_top_level_idx) within the search window"
1029        );
1030    }
1031
1032    // Visits in nearest upper-layer descent should be: entry_top, entry_top-1, ..., 1 (then the ground pass at 0 separately).
1033    #[cfg(not(madsim))]
1034    #[tokio::test]
1035    async fn hnsw_nearest_visits_expected_upper_layers() {
1036        use super::__hnsw_test_hooks as hooks;
1037
1038        // Build minimal graph with one node of 3 levels (indices 0,1,2).
1039        let dim = 1;
1040        let mut store = InMemoryVectorStore::new(dim);
1041        let v0 = gen_vector(dim);
1042        store.add(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0));
1043
1044        let graph = HnswGraphBuilder {
1045            entrypoint: 0,
1046            nodes: vec![VectorHnswNode {
1047                level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
1048            }],
1049            level_node_count: vec![0, 0, 1],
1050        };
1051
1052        // Query doesn't matter; we just want to observe level calls.
1053        let q = gen_vector(dim);
1054
1055        hooks::clear_levels();
1056
1057        // ef_search >= 1 to ensure we do the usual traversal.
1058        let Ok((_out, _stats)) = nearest::<usize, TestL2, _>(
1059            &store,
1060            &mut (),
1061            &graph,
1062            VectorRef::from_slice_unchecked(q.as_slice()),
1063            |_v, _d, _info| 0usize,
1064            4,
1065            1,
1066        )
1067        .await;
1068
1069        let visited = hooks::take_levels();
1070
1071        // Extract only upper-layer visits (>= 1). Ground layer (0) is handled later and isn't part of this loop.
1072        let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();
1073
1074        // With entry_top_level_idx = 2, we expect visits at levels [2, 1] in that order.
1075        assert_eq!(
1076            upper,
1077            vec![2, 1],
1078            "nearest should visit levels [2, 1] top-down before level 0"
1079        );
1080    }
1081
1082    #[cfg(not(madsim))]
1083    #[test]
1084    fn hnsw_vector_hnsw_node_level_returns_count() {
1085        // Construct a node with 3 level_neighbours (indices 0, 1, 2).
1086        let node = VectorHnswNode {
1087            level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
1088        };
1089
1090        // By contract, `num_level()` should return the COUNT (3), not the max index (2).
1091        assert_eq!(
1092            node.num_levels(),
1093            node.level_neighbours.len(),
1094            "VectorHnswNode::num_levels() must return the count of levels, \
1095             not the maximum level index"
1096        );
1097    }
1098
1099    #[cfg(not(madsim))]
1100    #[test]
1101    fn hnsw_graph_node_level_returns_count() {
1102        // Graph with a single node with 4 levels (indices 0,1,2,3).
1103        let graph = HnswGraphBuilder {
1104            entrypoint: 0,
1105            nodes: vec![VectorHnswNode {
1106                level_neighbours: (0..4).map(|_| BoundedNearest::new(0)).collect(),
1107            }],
1108            level_node_count: vec![0, 0, 0, 1],
1109        };
1110
1111        assert_eq!(
1112            graph.node_num_levels(0),
1113            4,
1114            "HnswGraph::node_level() must return the COUNT of levels, not the max index"
1115        );
1116    }
1117}