risingwave_storage/vector/
hnsw.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::{max, min};
16use std::marker::PhantomData;
17
18use faiss::index::hnsw::Hnsw;
19use rand::Rng;
20use rand::distr::uniform::{UniformFloat, UniformSampler};
21use risingwave_pb::hummock::PbHnswGraph;
22use risingwave_pb::hummock::hnsw_graph::{PbHnswLevel, PbHnswNeighbor, PbHnswNode};
23
24use crate::hummock::HummockResult;
25use crate::vector::utils::{BoundedNearest, MinDistanceHeap};
26use crate::vector::{
27    MeasureDistance, MeasureDistanceBuilder, OnNearestItem, VectorDistance, VectorItem, VectorRef,
28};
29
/// Tunable parameters used when building an HNSW graph.
#[derive(Copy, Clone, Debug)]
pub struct HnswBuilderOptions {
    /// Base connectivity. Per-level neighbour-list bounds are derived from it by [`level_m`],
    /// and it also sets the level-generation factor `1 / ln(m)`.
    pub m: usize,
    /// Breadth (`ef`) of the per-level search run while linking a newly inserted node.
    pub ef_construction: usize,
    /// Upper bound on the top level index a newly generated node may receive.
    pub max_level: usize,
}
36
/// Neighbour-list bound for a node on `level`, given base connectivity `m`.
///
/// Borrowed from pgvector: the ground level (0) keeps twice as many connections as any
/// upper level.
pub fn level_m(m: usize, level: usize) -> usize {
    match level {
        0 => m * 2,
        _ => m,
    }
}
42
43impl HnswBuilderOptions {
44    fn m_l(&self) -> f32 {
45        1.0 / (self.m as f32).ln()
46    }
47}
48
49fn gen_level(options: &HnswBuilderOptions, rng: &mut impl Rng) -> usize {
50    let level = (-UniformFloat::<f32>::sample_single(0.0, 1.0, rng)
51        .unwrap()
52        .ln()
53        * options.m_l())
54    .floor() as usize;
55    min(level, options.max_level)
56}
57
58pub(crate) fn new_node(options: &HnswBuilderOptions, rng: &mut impl Rng) -> VectorHnswNode {
59    let level = gen_level(options, rng);
60    let mut level_neighbours = Vec::with_capacity(level + 1);
61    level_neighbours
62        .extend((0..=level).map(|level| BoundedNearest::new(level_m(options.m, level))));
63    VectorHnswNode { level_neighbours }
64}
65
66pub(crate) struct VectorHnswNode {
67    level_neighbours: Vec<BoundedNearest<usize>>,
68}
69
70impl VectorHnswNode {
71    /// Returns the number of levels this node has (levels are indexed 0..num_levels-1).
72    fn num_levels(&self) -> usize {
73        self.level_neighbours.len()
74    }
75}
76
77struct InMemoryVectorStore {
78    dimension: usize,
79    vector_payload: Vec<VectorItem>,
80    info_payload: Vec<u8>,
81    info_offsets: Vec<usize>,
82}
83
84impl InMemoryVectorStore {
85    fn new(dimension: usize) -> Self {
86        Self {
87            dimension,
88            vector_payload: vec![],
89            info_payload: Default::default(),
90            info_offsets: vec![],
91        }
92    }
93
94    fn len(&self) -> usize {
95        self.info_offsets.len()
96    }
97
98    fn vec_ref(&self, idx: usize) -> VectorRef<'_> {
99        assert!(idx < self.info_offsets.len());
100        let start = idx * self.dimension;
101        let end = start + self.dimension;
102        VectorRef::from_slice_unchecked(&self.vector_payload[start..end])
103    }
104
105    fn info(&self, idx: usize) -> &[u8] {
106        let start = self.info_offsets[idx];
107        let end = if idx < self.info_offsets.len() - 1 {
108            self.info_offsets[idx + 1]
109        } else {
110            self.info_payload.len()
111        };
112        &self.info_payload[start..end]
113    }
114
115    fn add(&mut self, vec: VectorRef<'_>, info: &[u8]) {
116        assert_eq!(vec.dimension(), self.dimension);
117
118        self.vector_payload.extend_from_slice(vec.as_slice());
119        let offset = self.info_payload.len();
120        self.info_payload.extend_from_slice(info);
121        self.info_offsets.push(offset);
122    }
123}
124
125pub trait VectorAccessor {
126    fn vec_ref(&self) -> VectorRef<'_>;
127
128    fn info(&self) -> &[u8];
129}
130
131pub trait VectorStore: 'static {
132    type Accessor<'a>: VectorAccessor + 'a
133    where
134        Self: 'a;
135    async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>>;
136}
137
138pub struct InMemoryVectorStoreAccessor<'a> {
139    vector_store_impl: &'a InMemoryVectorStore,
140    idx: usize,
141}
142
143impl VectorAccessor for InMemoryVectorStoreAccessor<'_> {
144    fn vec_ref(&self) -> VectorRef<'_> {
145        self.vector_store_impl.vec_ref(self.idx)
146    }
147
148    fn info(&self) -> &[u8] {
149        self.vector_store_impl.info(self.idx)
150    }
151}
152
153impl VectorStore for InMemoryVectorStore {
154    type Accessor<'a> = InMemoryVectorStoreAccessor<'a>;
155
156    async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>> {
157        Ok(InMemoryVectorStoreAccessor {
158            vector_store_impl: self,
159            idx,
160        })
161    }
162}
163
164#[expect(clippy::len_without_is_empty)]
165pub trait HnswGraph {
166    fn entrypoint(&self) -> usize;
167    fn len(&self) -> usize;
168    /// Returns the number of levels for `idx` (levels are indexed 0..num_levels-1).
169    fn node_num_levels(&self, idx: usize) -> usize;
170    fn node_neighbours(
171        &self,
172        idx: usize,
173        level: usize,
174    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_;
175}
176
177impl HnswGraph for PbHnswGraph {
178    fn entrypoint(&self) -> usize {
179        self.entrypoint_id as _
180    }
181
182    fn len(&self) -> usize {
183        self.nodes.len()
184    }
185
186    fn node_num_levels(&self, idx: usize) -> usize {
187        self.nodes[idx].levels.len()
188    }
189
190    fn node_neighbours(
191        &self,
192        idx: usize,
193        level: usize,
194    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
195        self.nodes[idx]
196            .levels
197            .get(level)
198            .into_iter()
199            .flat_map(|level| {
200                level
201                    .neighbors
202                    .iter()
203                    .map(|neighbor| (neighbor.vector_id as usize, neighbor.distance))
204            })
205    }
206}
207
208pub struct HnswGraphBuilder {
209    /// entrypoint of the graph: Some(`entrypoint_vector_idx`)
210    entrypoint: usize,
211    nodes: Vec<VectorHnswNode>,
212}
213
214impl HnswGraphBuilder {
215    pub(crate) fn first(node: VectorHnswNode) -> Self {
216        Self {
217            entrypoint: 0,
218            nodes: vec![node],
219        }
220    }
221
222    pub fn to_protobuf(&self) -> PbHnswGraph {
223        let mut nodes = Vec::with_capacity(self.nodes.len());
224        for node in &self.nodes {
225            let mut levels = Vec::with_capacity(node.num_levels());
226            for level in &node.level_neighbours {
227                let mut neighbors = Vec::with_capacity(level.len());
228                for (distance, &neighbor_index) in level {
229                    neighbors.push(PbHnswNeighbor {
230                        vector_id: neighbor_index as u64,
231                        distance,
232                    });
233                }
234                levels.push(PbHnswLevel { neighbors });
235            }
236            nodes.push(PbHnswNode { levels });
237        }
238        PbHnswGraph {
239            entrypoint_id: self.entrypoint as u64,
240            nodes,
241        }
242    }
243
244    pub fn from_protobuf(pb: &PbHnswGraph, m: usize) -> Self {
245        let entrypoint = pb.entrypoint_id as usize;
246        let nodes = pb
247            .nodes
248            .iter()
249            .map(|node| {
250                let level_neighbours = node
251                    .levels
252                    .iter()
253                    .enumerate()
254                    .map(|(level_idx, level)| {
255                        let level_m = level_m(m, level_idx);
256                        let mut nearest = BoundedNearest::new(level_m);
257                        for neighbor in &level.neighbors {
258                            nearest.insert(neighbor.distance, || neighbor.vector_id as _);
259                        }
260                        nearest
261                    })
262                    .collect();
263                VectorHnswNode { level_neighbours }
264            })
265            .collect();
266        Self { entrypoint, nodes }
267    }
268}
269
270impl HnswGraph for HnswGraphBuilder {
271    fn entrypoint(&self) -> usize {
272        self.entrypoint
273    }
274
275    fn len(&self) -> usize {
276        self.nodes.len()
277    }
278
279    fn node_num_levels(&self, idx: usize) -> usize {
280        self.nodes[idx].num_levels()
281    }
282
283    fn node_neighbours(
284        &self,
285        idx: usize,
286        level: usize,
287    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
288        (&self.nodes[idx].level_neighbours[level])
289            .into_iter()
290            .map(|(distance, &neighbour_index)| (neighbour_index, distance))
291    }
292}
293
294pub struct HnswBuilder<V: VectorStore, G: HnswGraph, M: MeasureDistanceBuilder, R: Rng> {
295    options: HnswBuilderOptions,
296
297    // payload
298    vector_store: V,
299    graph: Option<G>,
300
301    // utils
302    rng: R,
303    _measure: PhantomData<M>,
304}
305
/// Counters collected while inserting into or searching an HNSW graph.
#[derive(Debug, Default)]
pub struct HnswStats {
    /// Number of distance evaluations performed.
    distances_computed: usize,
    /// Number of candidates whose neighbour lists were expanded.
    nhops: usize,
}
311
/// Fixed-size set of node indices, used to mark nodes visited during a layer search.
///
/// The set also records which indices were marked, so `reset` clears only those entries.
/// `search_layer` resets the set once per level, so clearing the whole buffer would cost
/// O(graph size) per level and make building quadratic.
struct VecSet {
    // TODO: optimize with bitmap
    payload: Vec<bool>,
    /// Indices currently marked in `payload`, each recorded exactly once.
    touched: Vec<usize>,
}

impl VecSet {
    /// Creates an empty set for indices `0..size`.
    fn new(size: usize) -> Self {
        Self {
            payload: vec![false; size],
            touched: Vec::new(),
        }
    }

    /// Marks `idx`. Panics if `idx >= size`.
    fn set(&mut self, idx: usize) {
        if !self.payload[idx] {
            self.payload[idx] = true;
            self.touched.push(idx);
        }
    }

    /// Whether `idx` was marked since the last reset. Panics if `idx >= size`.
    fn is_set(&self, idx: usize) -> bool {
        self.payload[idx]
    }

    /// Unmarks every index, in time proportional to the number of marked indices.
    fn reset(&mut self) {
        for idx in self.touched.drain(..) {
            self.payload[idx] = false;
        }
    }
}
336
337impl<M: MeasureDistanceBuilder, R: Rng> HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, M, R> {
338    pub fn new(dimension: usize, rng: R, options: HnswBuilderOptions) -> Self {
339        Self {
340            options,
341            graph: None,
342            vector_store: InMemoryVectorStore::new(dimension),
343            rng,
344            _measure: Default::default(),
345        }
346    }
347
348    pub fn with_faiss_hnsw(self, faiss_hnsw: Hnsw<'_>) -> Self {
349        assert_eq!(self.vector_store.len(), faiss_hnsw.levels_raw().len());
350        let (entry_point, _max_level) = faiss_hnsw.entry_point().unwrap();
351        let levels = faiss_hnsw.levels_raw();
352        let Some(graph) = &self.graph else {
353            assert_eq!(levels.len(), 0);
354            return Self::new(self.vector_store.dimension, self.rng, self.options);
355        };
356        assert_eq!(levels.len(), graph.nodes.len());
357        let mut nodes = Vec::with_capacity(graph.nodes.len());
358        for (node, level_count) in levels.iter().enumerate() {
359            let level_count = *level_count as usize;
360            let mut level_neighbors = Vec::with_capacity(level_count);
361            for level_idx in 0..level_count {
362                let neighbors = faiss_hnsw.neighbors_raw(node, level_idx);
363                let mut nearest_neighbors = BoundedNearest::new(max(neighbors.len(), 1));
364                for &neighbor in neighbors {
365                    nearest_neighbors.insert(
366                        M::distance(
367                            self.vector_store.vec_ref(node),
368                            self.vector_store.vec_ref(neighbor as _),
369                        ),
370                        || neighbor as _,
371                    );
372                }
373                level_neighbors.push(nearest_neighbors);
374            }
375            nodes.push(VectorHnswNode {
376                level_neighbours: level_neighbors,
377            });
378        }
379        Self {
380            options: self.options,
381            graph: Some(HnswGraphBuilder {
382                entrypoint: entry_point,
383                nodes,
384            }),
385            vector_store: self.vector_store,
386            rng: self.rng,
387            _measure: Default::default(),
388        }
389    }
390
391    pub fn print_graph(&self) {
392        let Some(graph) = &self.graph else {
393            println!("empty graph");
394            return;
395        };
396        println!(
397            "entrypoint {} has {} levels",
398            graph.entrypoint,
399            graph.nodes[graph.entrypoint].num_levels()
400        );
401        for (i, node) in graph.nodes.iter().enumerate() {
402            println!("node {} has {} levels", i, node.num_levels());
403            for level in 0..node.num_levels() {
404                print!("level {}: ", level);
405                for (_, &neighbor) in &node.level_neighbours[level] {
406                    print!("{} ", neighbor);
407                }
408                println!()
409            }
410        }
411    }
412
413    pub async fn insert(&mut self, vec: VectorRef<'_>, info: &[u8]) -> HummockResult<HnswStats> {
414        let node = new_node(&self.options, &mut self.rng);
415        let stat = if let Some(graph) = &mut self.graph {
416            insert_graph::<M>(
417                &self.vector_store,
418                graph,
419                node,
420                vec,
421                self.options.ef_construction,
422            )
423            .await?
424        } else {
425            self.graph = Some(HnswGraphBuilder::first(node));
426            HnswStats::default()
427        };
428        self.vector_store.add(vec, info);
429        Ok(stat)
430    }
431}
432
433pub(crate) async fn insert_graph<M: MeasureDistanceBuilder>(
434    vector_store: &impl VectorStore,
435    graph: &mut HnswGraphBuilder,
436    mut node: VectorHnswNode,
437    vec: VectorRef<'_>,
438    ef_construction: usize,
439) -> HummockResult<HnswStats> {
440    {
441        let mut stats = HnswStats::default();
442        let mut visited = VecSet::new(graph.nodes.len());
443        let entrypoint_index = graph.entrypoint();
444        let measure = M::new(vec);
445        let mut entrypoints = BoundedNearest::new(1);
446
447        entrypoints.insert(
448            measure.measure(vector_store.get_vector(entrypoint_index).await?.vec_ref()),
449            || (entrypoint_index, ()),
450        );
451        stats.distances_computed += 1;
452
453        // Walk from entrypoint's top level down to (node_top + 1), inclusive.
454        let entry_top_level_idx = graph.nodes[entrypoint_index].num_levels() - 1;
455        let node_top_level_idx = node.num_levels() - 1;
456
457        for level_idx in ((node_top_level_idx + 1)..=entry_top_level_idx).rev() {
458            entrypoints = search_layer(
459                vector_store,
460                &*graph,
461                &measure,
462                |_, _, _| (),
463                entrypoints,
464                level_idx,
465                1,
466                &mut stats,
467                &mut visited,
468            )
469            .await?;
470        }
471
472        // Connect from min(entry_top, node_top) down to ground (0).
473        let start_level_idx = min(entry_top_level_idx, node_top_level_idx);
474        for level_idx in (0..=start_level_idx).rev() {
475            entrypoints = search_layer(
476                vector_store,
477                &*graph,
478                &measure,
479                |_, _, _| (),
480                entrypoints,
481                level_idx,
482                ef_construction,
483                &mut stats,
484                &mut visited,
485            )
486            .await?;
487            let level_neighbour = &mut node.level_neighbours[level_idx];
488            for (neighbour_distance, &(neighbour_index, _)) in &entrypoints {
489                level_neighbour.insert(neighbour_distance, || neighbour_index);
490            }
491        }
492
493        let vector_index = graph.nodes.len();
494        for (level_index, level) in node.level_neighbours.iter().enumerate() {
495            for (neighbour_distance, &neighbour_index) in level {
496                graph.nodes[neighbour_index].level_neighbours[level_index]
497                    .insert(neighbour_distance, || vector_index);
498            }
499        }
500        if graph.nodes[entrypoint_index].num_levels() < node.num_levels() {
501            graph.entrypoint = vector_index;
502        }
503        graph.nodes.push(node);
504        Ok(stats)
505    }
506}
507
508pub async fn nearest<O: Send, M: MeasureDistanceBuilder>(
509    vector_store: &impl VectorStore,
510    graph: &impl HnswGraph,
511    vec: VectorRef<'_>,
512    on_nearest_fn: impl OnNearestItem<O>,
513    ef_search: usize,
514    top_n: usize,
515) -> HummockResult<(Vec<O>, HnswStats)> {
516    {
517        // Fast path: if no exploration breadth or no results requested, do nothing.
518        // Returns empty results and zeroed stats.
519        let mut stats = HnswStats::default();
520        if ef_search == 0 || top_n == 0 {
521            return Ok((Vec::new(), stats));
522        }
523
524        let entrypoint_index = graph.entrypoint();
525        let measure = M::new(vec);
526        let mut entrypoints = BoundedNearest::new(1);
527        let entrypoint_vector = vector_store.get_vector(entrypoint_index).await?;
528        let entrypoint_distance = measure.measure(entrypoint_vector.vec_ref());
529        entrypoints.insert(entrypoint_distance, || {
530            (
531                entrypoint_index,
532                on_nearest_fn(
533                    entrypoint_vector.vec_ref(),
534                    entrypoint_distance,
535                    entrypoint_vector.info(),
536                ),
537            )
538        });
539        stats.distances_computed += 1;
540        let entry_top_level_idx = graph.node_num_levels(entrypoint_index) - 1;
541        let mut visited = VecSet::new(graph.len());
542        for level_idx in (1..=entry_top_level_idx).rev() {
543            entrypoints = search_layer(
544                vector_store,
545                graph,
546                &measure,
547                &on_nearest_fn,
548                entrypoints,
549                level_idx, // level index
550                1,
551                &mut stats,
552                &mut visited,
553            )
554            .await?;
555        }
556        entrypoints = search_layer(
557            vector_store,
558            graph,
559            &measure,
560            &on_nearest_fn,
561            entrypoints,
562            0,
563            ef_search,
564            &mut stats,
565            &mut visited,
566        )
567        .await?;
568        Ok((
569            entrypoints.collect_with(|_, (_, output)| output, Some(top_n)),
570            stats,
571        ))
572    }
573}
574
575async fn search_layer<O: Send>(
576    vector_store: &impl VectorStore,
577    graph: &impl HnswGraph,
578    measure: &impl MeasureDistance,
579    on_nearest_fn: impl OnNearestItem<O>,
580    entrypoints: BoundedNearest<(usize, O)>,
581    level_index: usize,
582    ef: usize,
583    stats: &mut HnswStats,
584    visited: &mut VecSet,
585) -> HummockResult<BoundedNearest<(usize, O)>> {
586    #[cfg(test)]
587    {
588        __hnsw_test_hooks::record_level(level_index);
589    }
590    {
591        visited.reset();
592
593        let mut candidates = MinDistanceHeap::with_capacity(ef);
594        for (distance, &(idx, _)) in &entrypoints {
595            visited.set(idx);
596            candidates.push(distance, idx);
597        }
598        let mut nearest = entrypoints;
599        nearest.resize(ef);
600
601        while let Some((c_distance, c_index)) = candidates.pop() {
602            let (f_distance, _) = nearest.furthest().expect("non-empty");
603            if c_distance > f_distance {
604                // early break here when even the nearest node in `candidates` is further than the
605                // furthest node in the `nearest` set, because no node in `candidates` can be added to `nearest`
606                break;
607            }
608            stats.nhops += 1;
609            for (neighbour_index, _) in graph.node_neighbours(c_index, level_index) {
610                if visited.is_set(neighbour_index) {
611                    continue;
612                }
613                visited.set(neighbour_index);
614                let vector = vector_store.get_vector(neighbour_index).await?;
615                let info = vector.info();
616
617                let distance = measure.measure(vector.vec_ref());
618
619                stats.distances_computed += 1;
620
621                let mut added = false;
622                let added = &mut added;
623                nearest.insert(distance, || {
624                    *added = true;
625                    (
626                        neighbour_index,
627                        on_nearest_fn(vector.vec_ref(), distance, info),
628                    )
629                });
630                if *added {
631                    candidates.push(distance, neighbour_index);
632                }
633            }
634        }
635
636        Ok(nearest)
637    }
638}
639
#[cfg(test)]
mod __hnsw_test_hooks {
    //! Test-only instrumentation that logs, per thread, the level indices `search_layer` visits.

    use std::cell::RefCell;

    thread_local! {
        static LEVELS: RefCell<Vec<usize>> = const { RefCell::new(Vec::new()) };
    }

    /// Appends `level` to this thread's log.
    pub fn record_level(level: usize) {
        LEVELS.with_borrow_mut(|levels| levels.push(level));
    }

    /// Returns this thread's log, leaving it empty.
    pub fn take_levels() -> Vec<usize> {
        LEVELS.take()
    }

    /// Discards this thread's log.
    pub fn clear_levels() {
        LEVELS.with_borrow_mut(Vec::clear);
    }
}
661
662#[cfg(test)]
663mod tests {
664    use std::collections::HashSet;
665    use std::iter::repeat_with;
666    use std::time::{Duration, Instant};
667
668    use bytes::Bytes;
669    use faiss::{ConcurrentIndex, Index, MetricType};
670    use futures::executor::block_on;
671    use itertools::Itertools;
672    use rand::SeedableRng;
673    use rand::prelude::StdRng;
674    use risingwave_common::types::F32;
675    use risingwave_common::util::iter_util::ZipEqDebug;
676    use risingwave_common::vector::distance::InnerProductDistance;
677
678    use super::*;
679    use crate::vector::NearestBuilder;
680    use crate::vector::hnsw::{HnswBuilder, HnswBuilderOptions, nearest};
681    use crate::vector::test_utils::{gen_info, gen_vector};
682
683    fn recall(actual: &Vec<Bytes>, expected: &Vec<Bytes>) -> f32 {
684        let expected: HashSet<_> = expected.iter().map(|b| b.as_ref()).collect();
685        (actual
686            .iter()
687            .filter(|info| expected.contains(info.as_ref()))
688            .count() as f32)
689            / (expected.len() as f32)
690    }
691
692    /// Minimal L2 distance for tests using only public traits.
693    struct TestL2;
694
695    struct TestL2Measure<'a> {
696        q: VectorRef<'a>,
697    }
698
699    impl MeasureDistanceBuilder for TestL2 {
700        type Measure<'a> = TestL2Measure<'a>;
701
702        fn new<'a>(q: VectorRef<'a>) -> Self::Measure<'a> {
703            TestL2Measure { q }
704        }
705
706        fn distance(a: VectorRef<'_>, b: VectorRef<'_>) -> VectorDistance {
707            // Sum of squared diffs over the public slice API.
708            a.as_slice()
709                .iter()
710                .zip_eq_debug(b.as_slice().iter())
711                .map(|(&x, &y)| {
712                    // VectorItem <-> f32 conversion (mirrors test_utils usage).
713                    let xf: f32 = x.into();
714                    let yf: f32 = y.into();
715                    let d = (xf as f64) - (yf as f64);
716                    d * d
717                })
718                .sum::<f64>()
719        }
720    }
721
722    impl<'a> MeasureDistance for TestL2Measure<'a> {
723        fn measure(&self, v: VectorRef<'_>) -> VectorDistance {
724            TestL2::distance(self.q, v)
725        }
726    }
727
728    fn opts(m: usize, efc: usize, max_level: usize) -> HnswBuilderOptions {
729        HnswBuilderOptions {
730            m,
731            ef_construction: efc,
732            max_level,
733        }
734    }
735
736    const VERBOSE: bool = false;
737    const VECTOR_LEN: usize = 128;
738    const INPUT_COUNT: usize = 20000;
739    const QUERY_COUNT: usize = 5000;
740    const TOP_N: usize = 10;
741    const EF_SEARCH_LIST: &[usize] = &[16];
742    // const EF_SEARCH_LIST: &'static [usize] = &[16, 30, 100];
743
744    #[cfg(not(madsim))]
745    #[tokio::test]
746    async fn test_hnsw_basic() {
747        let input = (0..INPUT_COUNT)
748            .map(|i| (gen_vector(VECTOR_LEN), gen_info(i)))
749            .collect_vec();
750        let m = 40;
751        let hnsw_start_time = Instant::now();
752        let mut hnsw_builder = HnswBuilder::<_, _, InnerProductDistance, _>::new(
753            VECTOR_LEN,
754            StdRng::seed_from_u64(233),
755            // StdRng::try_from_os_rng().unwrap(),
756            HnswBuilderOptions {
757                m,
758                ef_construction: 40,
759                max_level: 10,
760            },
761        );
762        for (vec, info) in &input {
763            hnsw_builder.insert(vec.to_ref(), info).await.unwrap();
764        }
765        println!("hnsw build time: {:?}", hnsw_start_time.elapsed());
766        if VERBOSE {
767            hnsw_builder.print_graph();
768        }
769
770        let faiss_hnsw_start_time = Instant::now();
771        let mut faiss_hnsw = faiss::index::hnsw::HnswFlatIndex::new(
772            VECTOR_LEN as _,
773            m as _,
774            MetricType::InnerProduct,
775        )
776        .unwrap();
777
778        faiss_hnsw
779            .add(F32::inner_slice(&hnsw_builder.vector_store.vector_payload))
780            .unwrap();
781        // for (vec, info) in &input {
782        //     faiss_hnsw.add(&vec.0).unwrap();
783        // }
784
785        if VERBOSE {
786            let faiss_hnsw = faiss_hnsw.hnsw();
787            let (entry_point, max_level) = faiss_hnsw.entry_point().unwrap();
788            println!("faiss hnsw entry_point: {} {}", entry_point, max_level);
789            let levels = faiss_hnsw.levels_raw();
790            println!("entry point level: {}", levels[entry_point]);
791            for level in 0..=max_level {
792                let neighbors = faiss_hnsw.neighbors_raw(entry_point, level);
793                println!("entry point level {} neighbors {:?}", level, neighbors);
794            }
795        }
796        println!(
797            "faiss hnsw build time: {:?}",
798            faiss_hnsw_start_time.elapsed()
799        );
800
801        // let hnsw_builder = hnsw_builder.with_faiss_hnsw(faiss_hnsw.hnsw());
802
803        let queries = (0..QUERY_COUNT)
804            .map(|_| gen_vector(VECTOR_LEN))
805            .collect_vec();
806        let expected = queries
807            .iter()
808            .map(|query| {
809                let mut nearest_builder =
810                    NearestBuilder::<'_, _, InnerProductDistance>::new(query.to_ref(), TOP_N);
811                nearest_builder.add(
812                    input
813                        .iter()
814                        .map(|(vec, info)| (vec.to_ref(), info.as_ref())),
815                    |_, _, info| Bytes::copy_from_slice(info),
816                );
817                nearest_builder.finish()
818            })
819            .collect_vec();
820        let faiss_start_time = Instant::now();
821        let repeat_query = if cfg!(debug_assertions) { 1 } else { 60 };
822        println!("start faiss query");
823        let faiss_actual = repeat_with(|| queries.iter().enumerate())
824            .take(repeat_query)
825            .flatten()
826            .map(|(i, query)| {
827                let start_time = Instant::now();
828                let actual = faiss_hnsw
829                    .assign(query.as_raw_slice(), TOP_N)
830                    .unwrap()
831                    .labels
832                    .into_iter()
833                    .filter_map(|i| i.get().map(|i| gen_info(i as _)))
834                    .collect_vec();
835                let recall = recall(&actual, &expected[i]);
836                (start_time.elapsed(), recall)
837            })
838            .collect_vec();
839        let faiss_query_time = faiss_start_time.elapsed();
840        println!("start query");
841        let actuals = EF_SEARCH_LIST
842            .iter()
843            .map(|&ef_search| {
844                let start_time = Instant::now();
845                let actuals = repeat_with(|| queries.iter().enumerate())
846                    .take(repeat_query)
847                    .flatten()
848                    .map(|(i, query)| {
849                        let start_time = Instant::now();
850                        let (actual, stats) = block_on(nearest::<_, InnerProductDistance>(
851                            &hnsw_builder.vector_store,
852                            hnsw_builder.graph.as_ref().unwrap(),
853                            query.to_ref(),
854                            |_, _, info| Bytes::copy_from_slice(info),
855                            ef_search,
856                            TOP_N,
857                        ))
858                        .unwrap();
859                        if VERBOSE {
860                            println!("stats: {:?}", stats);
861                        }
862                        let recall = recall(&actual, &expected[i]);
863                        (start_time.elapsed(), recall)
864                    })
865                    .collect_vec();
866                (actuals, start_time.elapsed())
867            })
868            .collect_vec();
869        if VERBOSE {
870            for i in 0..20 {
871                for elapsed in [&faiss_actual]
872                    .into_iter()
873                    .chain(actuals.iter().map(|(actual, _)| actual))
874                    .map(|actual| actual[i].0)
875                {
876                    print!("{:?}\t", elapsed);
877                }
878                println!();
879                for recall in [&faiss_actual]
880                    .into_iter()
881                    .chain(actuals.iter().map(|(actual, _)| actual))
882                    .map(|actual| actual[i].1)
883                {
884                    print!("{}\t", recall);
885                }
886                println!();
887            }
888        }
889        fn avg_recall(actual: &Vec<(Duration, f32)>) -> f32 {
890            actual.iter().map(|(_, elapsed)| *elapsed).sum::<f32>() / (actual.len() as f32)
891        }
892        println!("faiss {:?} {}", faiss_query_time, avg_recall(&faiss_actual));
893        for i in 0..EF_SEARCH_LIST.len() {
894            println!(
895                "ef_search[{}] {:?} {}",
896                EF_SEARCH_LIST[i],
897                actuals[i].1,
898                avg_recall(&actuals[i].0)
899            );
900        }
901    }
902
903    // Visits in insert_graph upper-layer descent should be: entry_top, entry_top-1, ..., node_top+1 (inclusive).
904    #[cfg(not(madsim))]
905    #[tokio::test]
906    async fn hnsw_insert_graph_visits_expected_upper_layers() -> HummockResult<()> {
907        use rand::SeedableRng;
908        use rand::rngs::StdRng;
909
910        use super::__hnsw_test_hooks as hooks;
911
912        // Use the same options helper from this test module.
913        let dim = 8;
914        let options = opts(8, 16, 8); // m, ef_construction, max_level
915
916        // Try a handful of seeds so we reliably get an entry node with >= 3 levels.
917        // This remains deterministic across runs.
918        for seed in 1u64..=200 {
919            // Fresh builder per attempt so the RNG state matches expectations.
920            let mut hnsw: HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, TestL2, StdRng> =
921                HnswBuilder::new(dim, StdRng::seed_from_u64(seed), options);
922
923            // Insert first vector: becomes the entrypoint.
924            let v0 = gen_vector(dim);
925            let _ = hnsw
926                .insert(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0))
927                .await?;
928
929            // Peek current entrypoint's top level index.
930            let g = hnsw.graph.as_ref().unwrap();
931            let entry_idx = g.entrypoint();
932            let entry_top_level_idx = g.node_num_levels(entry_idx) - 1;
933
934            // We need at least 2 upper layers to make the assertion interesting.
935            if entry_top_level_idx < 2 {
936                continue; // try next seed
937            }
938
939            // Insert second vector and record which levels search_layer visits.
940            hooks::clear_levels();
941            let v1 = gen_vector(dim);
942            let _ = hnsw
943                .insert(VectorRef::from_slice_unchecked(v1.as_slice()), &gen_info(1))
944                .await?;
945
946            // After insertion, read the new node's level count (it’s at the tail).
947            let g = hnsw.graph.as_ref().unwrap();
948            let new_idx = g.len() - 1;
949            let node_top_level_idx = g.node_num_levels(new_idx) - 1;
950
951            // If the new node's top level > entry_top, HNSW would promote it to entrypoint.
952            // That changes the descent semantics; skip such seeds to keep the assertion crisp.
953            if node_top_level_idx > entry_top_level_idx {
954                continue;
955            }
956
957            // What the algorithm *should* have visited on the first descent:
958            let expected: Vec<usize> = ((node_top_level_idx + 1)..=entry_top_level_idx)
959                .rev()
960                .collect();
961
962            // Extract the actual visited levels (recorded at the start of search_layer).
963            let visited = hooks::take_levels();
964            // Keep only *upper-layer* calls (>= 1); the final ground pass is level 0.
965            let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();
966
967            assert_eq!(
968                upper, expected,
969                "seed={seed}, entry_top_level_idx={entry_top_level_idx}, node_top_level_idx={node_top_level_idx}"
970            );
971            return Ok(()); // success on this seed
972        }
973
974        panic!(
975            "could not find a suitable seed (entry_top_level_idx>=2 and node_top<=entry_top_level_idx) within the search window"
976        );
977    }
978
979    // Visits in nearest upper-layer descent should be: entry_top, entry_top-1, ..., 1 (then the ground pass at 0 separately).
980    #[cfg(not(madsim))]
981    #[tokio::test]
982    async fn hnsw_nearest_visits_expected_upper_layers() -> HummockResult<()> {
983        use super::__hnsw_test_hooks as hooks;
984
985        // Build minimal graph with one node of 3 levels (indices 0,1,2).
986        let dim = 1;
987        let mut store = InMemoryVectorStore::new(dim);
988        let v0 = gen_vector(dim);
989        store.add(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0));
990
991        let graph = HnswGraphBuilder {
992            entrypoint: 0,
993            nodes: vec![VectorHnswNode {
994                level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
995            }],
996        };
997
998        // Query doesn't matter; we just want to observe level calls.
999        let q = gen_vector(dim);
1000
1001        hooks::clear_levels();
1002
1003        // ef_search >= 1 to ensure we do the usual traversal.
1004        let (_out, _stats) = nearest::<usize, TestL2>(
1005            &store,
1006            &graph,
1007            VectorRef::from_slice_unchecked(q.as_slice()),
1008            |_v, _d, _info| 0usize,
1009            4,
1010            1,
1011        )
1012        .await?;
1013
1014        let visited = hooks::take_levels();
1015
1016        // Extract only upper-layer visits (>= 1). Ground layer (0) is handled later and isn't part of this loop.
1017        let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();
1018
1019        // With entry_top_level_idx = 2, we expect visits at levels [2, 1] in that order.
1020        assert_eq!(
1021            upper,
1022            vec![2, 1],
1023            "nearest should visit levels [2, 1] top-down before level 0"
1024        );
1025        Ok(())
1026    }
1027
1028    #[cfg(not(madsim))]
1029    #[test]
1030    fn hnsw_vector_hnsw_node_level_returns_count() {
1031        // Construct a node with 3 level_neighbours (indices 0, 1, 2).
1032        let node = VectorHnswNode {
1033            level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
1034        };
1035
1036        // By contract, `num_level()` should return the COUNT (3), not the max index (2).
1037        assert_eq!(
1038            node.num_levels(),
1039            node.level_neighbours.len(),
1040            "VectorHnswNode::num_levels() must return the count of levels, \
1041             not the maximum level index"
1042        );
1043    }
1044
1045    #[cfg(not(madsim))]
1046    #[test]
1047    fn hnsw_graph_node_level_returns_count() {
1048        // Graph with a single node with 4 levels (indices 0,1,2,3).
1049        let graph = HnswGraphBuilder {
1050            entrypoint: 0,
1051            nodes: vec![VectorHnswNode {
1052                level_neighbours: (0..4).map(|_| BoundedNearest::new(0)).collect(),
1053            }],
1054        };
1055
1056        assert_eq!(
1057            graph.node_num_levels(0),
1058            4,
1059            "HnswGraph::node_level() must return the COUNT of levels, not the max index"
1060        );
1061    }
1062}