1use std::cmp::{max, min};
16use std::marker::PhantomData;
17
18use faiss::index::hnsw::Hnsw;
19use itertools::Itertools;
20use rand::Rng;
21use rand::distr::uniform::{UniformFloat, UniformSampler};
22use risingwave_pb::hummock::PbHnswGraph;
23use risingwave_pb::hummock::hnsw_graph::{PbHnswLevel, PbHnswNeighbor, PbHnswNode};
24
25use crate::vector::utils::{BoundedNearest, MinDistanceHeap};
26use crate::vector::{
27 MeasureDistance, MeasureDistanceBuilder, OnNearestItem, VectorDistance, VectorItem, VectorRef,
28};
29
/// Parameters that steer HNSW graph construction.
#[derive(Clone, Copy)]
pub struct HnswBuilderOptions {
    /// Base neighbour-list capacity. It is passed to `level_m` to size each per-level
    /// neighbour list and determines the level-generation multiplier `1 / ln(m)`.
    pub m: usize,
    /// Size of the candidate set (`ef`) used while searching a layer during insertion.
    pub ef_construction: usize,
    /// Upper bound (inclusive) on the level index drawn for a new node.
    pub max_level: usize,
}
36
/// Returns the neighbour-list capacity for `level`: `2 * m` on the base level (level 0) and
/// `m` on every level above it.
pub fn level_m(m: usize, level: usize) -> usize {
    match level {
        0 => 2 * m,
        _ => m,
    }
}
42
43impl HnswBuilderOptions {
44 fn m_l(&self) -> f32 {
45 1.0 / (self.m as f32).ln()
46 }
47}
48
49fn gen_level(options: &HnswBuilderOptions, rng: &mut impl Rng) -> usize {
50 let level = (-UniformFloat::<f32>::sample_single(0.0, 1.0, rng)
51 .unwrap()
52 .ln()
53 * options.m_l())
54 .floor() as usize;
55 min(level, options.max_level)
56}
57
58pub(crate) fn new_node(options: &HnswBuilderOptions, rng: &mut impl Rng) -> VectorHnswNode {
59 let level = gen_level(options, rng);
60 let mut level_neighbours = Vec::with_capacity(level + 1);
61 level_neighbours
62 .extend((0..=level).map(|level| BoundedNearest::new(level_m(options.m, level))));
63 VectorHnswNode { level_neighbours }
64}
65
66pub(crate) struct VectorHnswNode {
67 level_neighbours: Vec<BoundedNearest<usize>>,
68}
69
70impl VectorHnswNode {
71 fn num_levels(&self) -> usize {
73 self.level_neighbours.len()
74 }
75}
76
77struct InMemoryVectorStore {
78 dimension: usize,
79 vector_payload: Vec<VectorItem>,
80 info_payload: Vec<u8>,
81 info_offsets: Vec<usize>,
82}
83
84impl InMemoryVectorStore {
85 fn new(dimension: usize) -> Self {
86 Self {
87 dimension,
88 vector_payload: vec![],
89 info_payload: Default::default(),
90 info_offsets: vec![],
91 }
92 }
93
94 fn len(&self) -> usize {
95 self.info_offsets.len()
96 }
97
98 fn vec_ref(&self, idx: usize) -> VectorRef<'_> {
99 assert!(idx < self.info_offsets.len());
100 let start = idx * self.dimension;
101 let end = start + self.dimension;
102 VectorRef::from_slice_unchecked(&self.vector_payload[start..end])
103 }
104
105 fn info(&self, idx: usize) -> &[u8] {
106 let start = self.info_offsets[idx];
107 let end = if idx < self.info_offsets.len() - 1 {
108 self.info_offsets[idx + 1]
109 } else {
110 self.info_payload.len()
111 };
112 &self.info_payload[start..end]
113 }
114
115 fn add(&mut self, vec: VectorRef<'_>, info: &[u8]) {
116 assert_eq!(vec.dimension(), self.dimension);
117
118 self.vector_payload.extend_from_slice(vec.as_slice());
119 let offset = self.info_payload.len();
120 self.info_payload.extend_from_slice(info);
121 self.info_offsets.push(offset);
122 }
123}
124
125pub trait VectorAccessor {
126 fn vec_ref(&self) -> VectorRef<'_>;
127
128 fn info(&self) -> &[u8];
129}
130
131pub trait VectorStore: 'static {
132 type Accessor<'a>: VectorAccessor + 'a
133 where
134 Self: 'a;
135 type Ctx: 'static;
136 type Err;
137 async fn get_vector(
138 &self,
139 idx: usize,
140 ctx: &mut Self::Ctx,
141 ) -> Result<Self::Accessor<'_>, Self::Err>;
142}
143
144pub struct InMemoryVectorStoreAccessor<'a> {
145 vector_store_impl: &'a InMemoryVectorStore,
146 idx: usize,
147}
148
149impl VectorAccessor for InMemoryVectorStoreAccessor<'_> {
150 fn vec_ref(&self) -> VectorRef<'_> {
151 self.vector_store_impl.vec_ref(self.idx)
152 }
153
154 fn info(&self) -> &[u8] {
155 self.vector_store_impl.info(self.idx)
156 }
157}
158
159impl VectorStore for InMemoryVectorStore {
160 type Accessor<'a> = InMemoryVectorStoreAccessor<'a>;
161 type Ctx = ();
162 type Err = !;
163
164 async fn get_vector(&self, idx: usize, _: &mut ()) -> Result<Self::Accessor<'_>, !> {
165 Ok(InMemoryVectorStoreAccessor {
166 vector_store_impl: self,
167 idx,
168 })
169 }
170}
171
172#[expect(clippy::len_without_is_empty)]
173pub trait HnswGraph {
174 fn entrypoint(&self) -> usize;
175 fn len(&self) -> usize;
176 fn node_num_levels(&self, idx: usize) -> usize;
178 fn node_neighbours(
179 &self,
180 idx: usize,
181 level: usize,
182 ) -> impl Iterator<Item = (usize, VectorDistance)> + '_;
183}
184
185impl HnswGraph for PbHnswGraph {
186 fn entrypoint(&self) -> usize {
187 self.entrypoint_id as _
188 }
189
190 fn len(&self) -> usize {
191 self.nodes.len()
192 }
193
194 fn node_num_levels(&self, idx: usize) -> usize {
195 self.nodes[idx].levels.len()
196 }
197
198 fn node_neighbours(
199 &self,
200 idx: usize,
201 level: usize,
202 ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
203 self.nodes[idx]
204 .levels
205 .get(level)
206 .into_iter()
207 .flat_map(|level| {
208 level
209 .neighbors
210 .iter()
211 .map(|neighbor| (neighbor.vector_id as usize, neighbor.distance))
212 })
213 }
214}
215
216pub struct HnswGraphBuilder {
217 entrypoint: usize,
219 nodes: Vec<VectorHnswNode>,
220 level_node_count: Vec<usize>,
221}
222
223fn recovery_level_count(entrypoint: usize, nodes: &[VectorHnswNode]) -> Vec<usize> {
224 let mut level_count = vec![0; nodes[entrypoint].num_levels()];
225 for node in nodes {
226 level_count[node.num_levels() - 1] += 1;
227 }
228 level_count
229}
230
231impl HnswGraphBuilder {
232 pub(crate) fn first(node: VectorHnswNode) -> Self {
233 let nodes = vec![node];
234 let level_count = recovery_level_count(0, &nodes);
235 Self {
236 entrypoint: 0,
237 nodes,
238 level_node_count: level_count,
239 }
240 }
241
242 pub fn to_protobuf(&self) -> PbHnswGraph {
243 let mut nodes = Vec::with_capacity(self.nodes.len());
244 for node in &self.nodes {
245 let mut levels = Vec::with_capacity(node.num_levels());
246 for level in &node.level_neighbours {
247 let mut neighbors = Vec::with_capacity(level.len());
248 for (distance, &neighbor_index) in level {
249 neighbors.push(PbHnswNeighbor {
250 vector_id: neighbor_index as u64,
251 distance,
252 });
253 }
254 levels.push(PbHnswLevel { neighbors });
255 }
256 nodes.push(PbHnswNode { levels });
257 }
258 PbHnswGraph {
259 entrypoint_id: self.entrypoint as u64,
260 nodes,
261 }
262 }
263
264 pub fn from_protobuf(pb: &PbHnswGraph, m: usize) -> Self {
265 let entrypoint = pb.entrypoint_id as usize;
266 let nodes = pb
267 .nodes
268 .iter()
269 .map(|node| {
270 let level_neighbours = node
271 .levels
272 .iter()
273 .enumerate()
274 .map(|(level_idx, level)| {
275 let level_m = level_m(m, level_idx);
276 let mut nearest = BoundedNearest::new(level_m);
277 for neighbor in &level.neighbors {
278 nearest.insert(neighbor.distance, || neighbor.vector_id as _);
279 }
280 nearest
281 })
282 .collect();
283 VectorHnswNode { level_neighbours }
284 })
285 .collect_vec();
286 let level_count = recovery_level_count(entrypoint, &nodes);
287 Self {
288 entrypoint,
289 nodes,
290 level_node_count: level_count,
291 }
292 }
293
294 pub fn level_node_count(&self) -> &[usize] {
295 &self.level_node_count
296 }
297}
298
299impl HnswGraph for HnswGraphBuilder {
300 fn entrypoint(&self) -> usize {
301 self.entrypoint
302 }
303
304 fn len(&self) -> usize {
305 self.nodes.len()
306 }
307
308 fn node_num_levels(&self, idx: usize) -> usize {
309 self.nodes[idx].num_levels()
310 }
311
312 fn node_neighbours(
313 &self,
314 idx: usize,
315 level: usize,
316 ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
317 (&self.nodes[idx].level_neighbours[level])
318 .into_iter()
319 .map(|(distance, &neighbour_index)| (neighbour_index, distance))
320 }
321}
322
323pub struct HnswBuilder<V: VectorStore, G: HnswGraph, M: MeasureDistanceBuilder, R: Rng> {
324 options: HnswBuilderOptions,
325
326 vector_store: V,
328 ctx: V::Ctx,
329 graph: Option<G>,
330
331 rng: R,
333 _measure: PhantomData<M>,
334}
335
/// Counters describing the work done by one insert or one search.
#[derive(Debug, Default)]
pub struct HnswStats {
    /// Number of distance evaluations between the query and a stored vector.
    pub distances_computed: usize,
    /// Number of candidates that were expanded, i.e. had their neighbour list scanned.
    pub n_hops: usize,
}
341
/// Fixed-capacity set of indices backed by one `bool` per slot.
struct VecSet {
    /// `payload[i]` is `true` when index `i` is a member.
    payload: Vec<bool>,
}

impl VecSet {
    /// Creates an empty set able to hold the indices `0..size`.
    fn new(size: usize) -> Self {
        let payload = vec![false; size];
        Self { payload }
    }

    /// Adds `idx` to the set. Panics if `idx` is outside the capacity.
    fn set(&mut self, idx: usize) {
        let slot = &mut self.payload[idx];
        *slot = true;
    }

    /// Reports whether `idx` is a member. Panics if `idx` is outside the capacity.
    fn is_set(&self, idx: usize) -> bool {
        let Self { payload } = self;
        payload[idx]
    }

    /// Removes every member while keeping the capacity.
    fn reset(&mut self) {
        self.payload.as_mut_slice().fill(false);
    }
}
366
367impl<M: MeasureDistanceBuilder, R: Rng> HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, M, R> {
368 pub fn new(dimension: usize, rng: R, options: HnswBuilderOptions) -> Self {
369 Self {
370 options,
371 graph: None,
372 vector_store: InMemoryVectorStore::new(dimension),
373 ctx: (),
374 rng,
375 _measure: Default::default(),
376 }
377 }
378
379 pub fn with_faiss_hnsw(self, faiss_hnsw: Hnsw<'_>) -> Self {
380 assert_eq!(self.vector_store.len(), faiss_hnsw.levels_raw().len());
381 let (entry_point, _max_level) = faiss_hnsw.entry_point().unwrap();
382 let levels = faiss_hnsw.levels_raw();
383 let Some(graph) = &self.graph else {
384 assert_eq!(levels.len(), 0);
385 return Self::new(self.vector_store.dimension, self.rng, self.options);
386 };
387 assert_eq!(levels.len(), graph.nodes.len());
388 let mut nodes = Vec::with_capacity(graph.nodes.len());
389 for (node, level_count) in levels.iter().enumerate() {
390 let level_count = *level_count as usize;
391 let mut level_neighbors = Vec::with_capacity(level_count);
392 for level_idx in 0..level_count {
393 let neighbors = faiss_hnsw.neighbors_raw(node, level_idx);
394 let mut nearest_neighbors = BoundedNearest::new(max(neighbors.len(), 1));
395 for &neighbor in neighbors {
396 nearest_neighbors.insert(
397 M::distance(
398 self.vector_store.vec_ref(node),
399 self.vector_store.vec_ref(neighbor as _),
400 ),
401 || neighbor as _,
402 );
403 }
404 level_neighbors.push(nearest_neighbors);
405 }
406 nodes.push(VectorHnswNode {
407 level_neighbours: level_neighbors,
408 });
409 }
410 let level_count = recovery_level_count(entry_point, &nodes);
411 Self {
412 options: self.options,
413 graph: Some(HnswGraphBuilder {
414 entrypoint: entry_point,
415 nodes,
416 level_node_count: level_count,
417 }),
418 vector_store: self.vector_store,
419 ctx: (),
420 rng: self.rng,
421 _measure: Default::default(),
422 }
423 }
424
425 pub fn print_graph(&self) {
426 let Some(graph) = &self.graph else {
427 println!("empty graph");
428 return;
429 };
430 println!(
431 "entrypoint {} has {} levels",
432 graph.entrypoint,
433 graph.nodes[graph.entrypoint].num_levels()
434 );
435 for (i, node) in graph.nodes.iter().enumerate() {
436 println!("node {} has {} levels", i, node.num_levels());
437 for level in 0..node.num_levels() {
438 print!("level {}: ", level);
439 for (_, &neighbor) in &node.level_neighbours[level] {
440 print!("{} ", neighbor);
441 }
442 println!()
443 }
444 }
445 }
446
447 pub async fn insert(&mut self, vec: VectorRef<'_>, info: &[u8]) -> HnswStats {
448 let node = new_node(&self.options, &mut self.rng);
449 let stat = if let Some(graph) = &mut self.graph {
450 let Ok(stat) = insert_graph::<M, _>(
451 &self.vector_store,
452 &mut self.ctx,
453 graph,
454 node,
455 vec,
456 self.options.ef_construction,
457 )
458 .await;
459 stat
460 } else {
461 self.graph = Some(HnswGraphBuilder::first(node));
462 HnswStats::default()
463 };
464 self.vector_store.add(vec, info);
465 stat
466 }
467}
468
469pub(crate) async fn insert_graph<M: MeasureDistanceBuilder, S: VectorStore>(
470 vector_store: &S,
471 ctx: &mut S::Ctx,
472 graph: &mut HnswGraphBuilder,
473 mut node: VectorHnswNode,
474 vec: VectorRef<'_>,
475 ef_construction: usize,
476) -> Result<HnswStats, S::Err> {
477 {
478 let mut stats = HnswStats::default();
479 let mut visited = VecSet::new(graph.nodes.len());
480 let entrypoint_index = graph.entrypoint();
481 let measure = M::new(vec);
482 let mut entrypoints = BoundedNearest::new(1);
483
484 entrypoints.insert(
485 measure.measure(
486 vector_store
487 .get_vector(entrypoint_index, ctx)
488 .await?
489 .vec_ref(),
490 ),
491 || (entrypoint_index, ()),
492 );
493 stats.distances_computed += 1;
494
495 let entry_top_level_idx = graph.nodes[entrypoint_index].num_levels() - 1;
497 let node_top_level_idx = node.num_levels() - 1;
498
499 for level_idx in ((node_top_level_idx + 1)..=entry_top_level_idx).rev() {
500 entrypoints = search_layer(
501 vector_store,
502 ctx,
503 &*graph,
504 &measure,
505 |_, _, _| (),
506 entrypoints,
507 level_idx,
508 1,
509 &mut stats,
510 &mut visited,
511 )
512 .await?;
513 }
514
515 let start_level_idx = min(entry_top_level_idx, node_top_level_idx);
517 for level_idx in (0..=start_level_idx).rev() {
518 entrypoints = search_layer(
519 vector_store,
520 ctx,
521 &*graph,
522 &measure,
523 |_, _, _| (),
524 entrypoints,
525 level_idx,
526 ef_construction,
527 &mut stats,
528 &mut visited,
529 )
530 .await?;
531 let level_neighbour = &mut node.level_neighbours[level_idx];
532 for (neighbour_distance, &(neighbour_index, _)) in &entrypoints {
533 level_neighbour.insert(neighbour_distance, || neighbour_index);
534 }
535 }
536
537 let vector_index = graph.nodes.len();
538 for (level_index, level) in node.level_neighbours.iter().enumerate() {
539 for (neighbour_distance, &neighbour_index) in level {
540 graph.nodes[neighbour_index].level_neighbours[level_index]
541 .insert(neighbour_distance, || vector_index);
542 }
543 }
544 if graph.nodes[entrypoint_index].num_levels() < node.num_levels() {
545 graph.entrypoint = vector_index;
546 graph.level_node_count.resize(node.num_levels(), 0);
547 }
548 graph.level_node_count[node.num_levels() - 1] += 1;
549 graph.nodes.push(node);
550 Ok(stats)
551 }
552}
553
554pub async fn nearest<O: Send, M: MeasureDistanceBuilder, S: VectorStore>(
555 vector_store: &S,
556 ctx: &mut S::Ctx,
557 graph: &impl HnswGraph,
558 vec: VectorRef<'_>,
559 on_nearest_fn: impl OnNearestItem<O>,
560 ef_search: usize,
561 top_n: usize,
562) -> Result<(Vec<O>, HnswStats), S::Err> {
563 {
564 let mut stats = HnswStats::default();
567 if ef_search == 0 || top_n == 0 {
568 return Ok((Vec::new(), stats));
569 }
570
571 let entrypoint_index = graph.entrypoint();
572 let measure = M::new(vec);
573 let mut entrypoints = BoundedNearest::new(1);
574 {
575 let entrypoint_vector = vector_store.get_vector(entrypoint_index, ctx).await?;
576 let entrypoint_distance = measure.measure(entrypoint_vector.vec_ref());
577 entrypoints.insert(entrypoint_distance, || {
578 (
579 entrypoint_index,
580 on_nearest_fn(
581 entrypoint_vector.vec_ref(),
582 entrypoint_distance,
583 entrypoint_vector.info(),
584 ),
585 )
586 });
587 }
588 stats.distances_computed += 1;
589 let entry_top_level_idx = graph.node_num_levels(entrypoint_index) - 1;
590 let mut visited = VecSet::new(graph.len());
591 for level_idx in (1..=entry_top_level_idx).rev() {
592 entrypoints = search_layer(
593 vector_store,
594 ctx,
595 graph,
596 &measure,
597 &on_nearest_fn,
598 entrypoints,
599 level_idx, 1,
601 &mut stats,
602 &mut visited,
603 )
604 .await?;
605 }
606 entrypoints = search_layer(
607 vector_store,
608 ctx,
609 graph,
610 &measure,
611 &on_nearest_fn,
612 entrypoints,
613 0,
614 ef_search,
615 &mut stats,
616 &mut visited,
617 )
618 .await?;
619 Ok((
620 entrypoints.collect_with(|_, (_, output)| output, Some(top_n)),
621 stats,
622 ))
623 }
624}
625
626async fn search_layer<O: Send, S: VectorStore>(
627 vector_store: &S,
628 ctx: &mut S::Ctx,
629 graph: &impl HnswGraph,
630 measure: &impl MeasureDistance,
631 on_nearest_fn: impl OnNearestItem<O>,
632 entrypoints: BoundedNearest<(usize, O)>,
633 level_index: usize,
634 ef: usize,
635 stats: &mut HnswStats,
636 visited: &mut VecSet,
637) -> Result<BoundedNearest<(usize, O)>, S::Err> {
638 #[cfg(test)]
639 {
640 __hnsw_test_hooks::record_level(level_index);
641 }
642 {
643 visited.reset();
644
645 let mut candidates = MinDistanceHeap::with_capacity(ef);
646 for (distance, &(idx, _)) in &entrypoints {
647 visited.set(idx);
648 candidates.push(distance, idx);
649 }
650 let mut nearest = entrypoints;
651 nearest.resize(ef);
652
653 while let Some((c_distance, c_index)) = candidates.pop() {
654 let (f_distance, _) = nearest.furthest().expect("non-empty");
655 if c_distance > f_distance {
656 break;
659 }
660 stats.n_hops += 1;
661 for (neighbour_index, _) in graph.node_neighbours(c_index, level_index) {
662 if visited.is_set(neighbour_index) {
663 continue;
664 }
665 visited.set(neighbour_index);
666 let vector = vector_store.get_vector(neighbour_index, ctx).await?;
667 let info = vector.info();
668
669 let distance = measure.measure(vector.vec_ref());
670
671 stats.distances_computed += 1;
672
673 let mut added = false;
674 let added = &mut added;
675 nearest.insert(distance, || {
676 *added = true;
677 (
678 neighbour_index,
679 on_nearest_fn(vector.vec_ref(), distance, info),
680 )
681 });
682 if *added {
683 candidates.push(distance, neighbour_index);
684 }
685 }
686 }
687
688 Ok(nearest)
689 }
690}
691
/// Test-only instrumentation that records which graph levels were searched on this thread.
#[cfg(test)]
mod __hnsw_test_hooks {
    use std::cell::RefCell;

    thread_local! {
        #[allow(clippy::missing_const_for_thread_local)]
        static LEVELS: RefCell<Vec<usize>> = RefCell::new(Vec::new());
    }

    /// Appends `level` to this thread's log.
    pub fn record_level(level: usize) {
        LEVELS.with_borrow_mut(|levels| levels.push(level));
    }

    /// Returns this thread's log and leaves it empty.
    pub fn take_levels() -> Vec<usize> {
        LEVELS.take()
    }

    /// Discards everything recorded so far on this thread.
    pub fn clear_levels() {
        LEVELS.with_borrow_mut(Vec::clear);
    }
}
713
714#[cfg(test)]
715mod tests {
716 use std::collections::HashSet;
717 use std::iter::repeat_with;
718 use std::time::{Duration, Instant};
719
720 use bytes::Bytes;
721 use faiss::{ConcurrentIndex, Index, MetricType};
722 use futures::executor::block_on;
723 use itertools::Itertools;
724 use rand::SeedableRng;
725 use rand::prelude::StdRng;
726 use risingwave_common::types::F32;
727 use risingwave_common::util::iter_util::ZipEqDebug;
728 use risingwave_common::vector::distance::InnerProductDistance;
729
730 use super::*;
731 use crate::vector::NearestBuilder;
732 use crate::vector::hnsw::{HnswBuilder, HnswBuilderOptions, nearest};
733 use crate::vector::test_utils::{gen_info, gen_vector};
734
735 fn recall(actual: &Vec<Bytes>, expected: &Vec<Bytes>) -> f32 {
736 let expected: HashSet<_> = expected.iter().map(|b| b.as_ref()).collect();
737 (actual
738 .iter()
739 .filter(|info| expected.contains(info.as_ref()))
740 .count() as f32)
741 / (expected.len() as f32)
742 }
743
744 struct TestL2;
746
747 struct TestL2Measure<'a> {
748 q: VectorRef<'a>,
749 }
750
751 impl MeasureDistanceBuilder for TestL2 {
752 type Measure<'a> = TestL2Measure<'a>;
753
754 fn new<'a>(q: VectorRef<'a>) -> Self::Measure<'a> {
755 TestL2Measure { q }
756 }
757
758 fn distance(a: VectorRef<'_>, b: VectorRef<'_>) -> VectorDistance {
759 a.as_slice()
761 .iter()
762 .zip_eq_debug(b.as_slice().iter())
763 .map(|(&x, &y)| {
764 let xf: f32 = x.into();
766 let yf: f32 = y.into();
767 let d = (xf as f64) - (yf as f64);
768 d * d
769 })
770 .sum::<f64>()
771 }
772 }
773
774 impl<'a> MeasureDistance for TestL2Measure<'a> {
775 fn measure(&self, v: VectorRef<'_>) -> VectorDistance {
776 TestL2::distance(self.q, v)
777 }
778 }
779
780 fn opts(m: usize, efc: usize, max_level: usize) -> HnswBuilderOptions {
781 HnswBuilderOptions {
782 m,
783 ef_construction: efc,
784 max_level,
785 }
786 }
787
788 const VERBOSE: bool = false;
789 const VECTOR_LEN: usize = 128;
790 const INPUT_COUNT: usize = 20000;
791 const QUERY_COUNT: usize = 5000;
792 const TOP_N: usize = 10;
793 const EF_SEARCH_LIST: &[usize] = &[16];
794 #[cfg(not(madsim))]
797 #[tokio::test]
798 async fn test_hnsw_basic() {
799 let input = (0..INPUT_COUNT)
800 .map(|i| (gen_vector(VECTOR_LEN), gen_info(i)))
801 .collect_vec();
802 let m = 40;
803 let hnsw_start_time = Instant::now();
804 let mut hnsw_builder = HnswBuilder::<_, _, InnerProductDistance, _>::new(
805 VECTOR_LEN,
806 StdRng::seed_from_u64(233),
807 HnswBuilderOptions {
809 m,
810 ef_construction: 40,
811 max_level: 10,
812 },
813 );
814 for (vec, info) in &input {
815 hnsw_builder.insert(vec.to_ref(), info).await;
816 }
817 println!("hnsw build time: {:?}", hnsw_start_time.elapsed());
818 if VERBOSE {
819 hnsw_builder.print_graph();
820 }
821
822 let faiss_hnsw_start_time = Instant::now();
823 let mut faiss_hnsw = faiss::index::hnsw::HnswFlatIndex::new(
824 VECTOR_LEN as _,
825 m as _,
826 MetricType::InnerProduct,
827 )
828 .unwrap();
829
830 faiss_hnsw
831 .add(F32::inner_slice(&hnsw_builder.vector_store.vector_payload))
832 .unwrap();
833 if VERBOSE {
838 let faiss_hnsw = faiss_hnsw.hnsw();
839 let (entry_point, max_level) = faiss_hnsw.entry_point().unwrap();
840 println!("faiss hnsw entry_point: {} {}", entry_point, max_level);
841 let levels = faiss_hnsw.levels_raw();
842 println!("entry point level: {}", levels[entry_point]);
843 for level in 0..=max_level {
844 let neighbors = faiss_hnsw.neighbors_raw(entry_point, level);
845 println!("entry point level {} neighbors {:?}", level, neighbors);
846 }
847 }
848 println!(
849 "faiss hnsw build time: {:?}",
850 faiss_hnsw_start_time.elapsed()
851 );
852
853 let queries = (0..QUERY_COUNT)
856 .map(|_| gen_vector(VECTOR_LEN))
857 .collect_vec();
858 let expected = queries
859 .iter()
860 .map(|query| {
861 let mut nearest_builder =
862 NearestBuilder::<'_, _, InnerProductDistance>::new(query.to_ref(), TOP_N);
863 nearest_builder.add(
864 input
865 .iter()
866 .map(|(vec, info)| (vec.to_ref(), info.as_ref())),
867 |_, _, info| Bytes::copy_from_slice(info),
868 );
869 nearest_builder.finish()
870 })
871 .collect_vec();
872 let faiss_start_time = Instant::now();
873 let repeat_query = if cfg!(debug_assertions) { 1 } else { 60 };
874 println!("start faiss query");
875 let faiss_actual = repeat_with(|| queries.iter().enumerate())
876 .take(repeat_query)
877 .flatten()
878 .map(|(i, query)| {
879 let start_time = Instant::now();
880 let actual = faiss_hnsw
881 .assign(query.as_raw_slice(), TOP_N)
882 .unwrap()
883 .labels
884 .into_iter()
885 .filter_map(|i| i.get().map(|i| gen_info(i as _)))
886 .collect_vec();
887 let recall = recall(&actual, &expected[i]);
888 (start_time.elapsed(), recall)
889 })
890 .collect_vec();
891 let faiss_query_time = faiss_start_time.elapsed();
892 println!("start query");
893 let actuals = EF_SEARCH_LIST
894 .iter()
895 .map(|&ef_search| {
896 let start_time = Instant::now();
897 let actuals = repeat_with(|| queries.iter().enumerate())
898 .take(repeat_query)
899 .flatten()
900 .map(|(i, query)| {
901 let start_time = Instant::now();
902 let (actual, stats) = block_on(nearest::<_, InnerProductDistance, _>(
903 &hnsw_builder.vector_store,
904 &mut (),
905 hnsw_builder.graph.as_ref().unwrap(),
906 query.to_ref(),
907 |_, _, info| Bytes::copy_from_slice(info),
908 ef_search,
909 TOP_N,
910 ))
911 .unwrap();
912 if VERBOSE {
913 println!("stats: {:?}", stats);
914 }
915 let recall = recall(&actual, &expected[i]);
916 (start_time.elapsed(), recall)
917 })
918 .collect_vec();
919 (actuals, start_time.elapsed())
920 })
921 .collect_vec();
922 if VERBOSE {
923 for i in 0..20 {
924 for elapsed in [&faiss_actual]
925 .into_iter()
926 .chain(actuals.iter().map(|(actual, _)| actual))
927 .map(|actual| actual[i].0)
928 {
929 print!("{:?}\t", elapsed);
930 }
931 println!();
932 for recall in [&faiss_actual]
933 .into_iter()
934 .chain(actuals.iter().map(|(actual, _)| actual))
935 .map(|actual| actual[i].1)
936 {
937 print!("{}\t", recall);
938 }
939 println!();
940 }
941 }
942 fn avg_recall(actual: &Vec<(Duration, f32)>) -> f32 {
943 actual.iter().map(|(_, elapsed)| *elapsed).sum::<f32>() / (actual.len() as f32)
944 }
945 println!("faiss {:?} {}", faiss_query_time, avg_recall(&faiss_actual));
946 for i in 0..EF_SEARCH_LIST.len() {
947 println!(
948 "ef_search[{}] {:?} {}",
949 EF_SEARCH_LIST[i],
950 actuals[i].1,
951 avg_recall(&actuals[i].0)
952 );
953 }
954 }
955
956 #[cfg(not(madsim))]
958 #[tokio::test]
959 async fn hnsw_insert_graph_visits_expected_upper_layers() {
960 use rand::SeedableRng;
961 use rand::rngs::StdRng;
962
963 use super::__hnsw_test_hooks as hooks;
964
965 let dim = 8;
967 let options = opts(8, 16, 8); for seed in 1u64..=200 {
972 let mut hnsw: HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, TestL2, StdRng> =
974 HnswBuilder::new(dim, StdRng::seed_from_u64(seed), options);
975
976 let v0 = gen_vector(dim);
978 let _ = hnsw
979 .insert(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0))
980 .await;
981
982 let g = hnsw.graph.as_ref().unwrap();
984 let entry_idx = g.entrypoint();
985 let entry_top_level_idx = g.node_num_levels(entry_idx) - 1;
986
987 if entry_top_level_idx < 2 {
989 continue; }
991
992 hooks::clear_levels();
994 let v1 = gen_vector(dim);
995 let _ = hnsw
996 .insert(VectorRef::from_slice_unchecked(v1.as_slice()), &gen_info(1))
997 .await;
998
999 let g = hnsw.graph.as_ref().unwrap();
1001 let new_idx = g.len() - 1;
1002 let node_top_level_idx = g.node_num_levels(new_idx) - 1;
1003
1004 if node_top_level_idx > entry_top_level_idx {
1007 continue;
1008 }
1009
1010 let expected: Vec<usize> = ((node_top_level_idx + 1)..=entry_top_level_idx)
1012 .rev()
1013 .collect();
1014
1015 let visited = hooks::take_levels();
1017 let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();
1019
1020 assert_eq!(
1021 upper, expected,
1022 "seed={seed}, entry_top_level_idx={entry_top_level_idx}, node_top_level_idx={node_top_level_idx}"
1023 );
1024 return; }
1026
1027 panic!(
1028 "could not find a suitable seed (entry_top_level_idx>=2 and node_top<=entry_top_level_idx) within the search window"
1029 );
1030 }
1031
1032 #[cfg(not(madsim))]
1034 #[tokio::test]
1035 async fn hnsw_nearest_visits_expected_upper_layers() {
1036 use super::__hnsw_test_hooks as hooks;
1037
1038 let dim = 1;
1040 let mut store = InMemoryVectorStore::new(dim);
1041 let v0 = gen_vector(dim);
1042 store.add(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0));
1043
1044 let graph = HnswGraphBuilder {
1045 entrypoint: 0,
1046 nodes: vec![VectorHnswNode {
1047 level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
1048 }],
1049 level_node_count: vec![0, 0, 1],
1050 };
1051
1052 let q = gen_vector(dim);
1054
1055 hooks::clear_levels();
1056
1057 let Ok((_out, _stats)) = nearest::<usize, TestL2, _>(
1059 &store,
1060 &mut (),
1061 &graph,
1062 VectorRef::from_slice_unchecked(q.as_slice()),
1063 |_v, _d, _info| 0usize,
1064 4,
1065 1,
1066 )
1067 .await;
1068
1069 let visited = hooks::take_levels();
1070
1071 let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();
1073
1074 assert_eq!(
1076 upper,
1077 vec![2, 1],
1078 "nearest should visit levels [2, 1] top-down before level 0"
1079 );
1080 }
1081
1082 #[cfg(not(madsim))]
1083 #[test]
1084 fn hnsw_vector_hnsw_node_level_returns_count() {
1085 let node = VectorHnswNode {
1087 level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
1088 };
1089
1090 assert_eq!(
1092 node.num_levels(),
1093 node.level_neighbours.len(),
1094 "VectorHnswNode::num_levels() must return the count of levels, \
1095 not the maximum level index"
1096 );
1097 }
1098
1099 #[cfg(not(madsim))]
1100 #[test]
1101 fn hnsw_graph_node_level_returns_count() {
1102 let graph = HnswGraphBuilder {
1104 entrypoint: 0,
1105 nodes: vec![VectorHnswNode {
1106 level_neighbours: (0..4).map(|_| BoundedNearest::new(0)).collect(),
1107 }],
1108 level_node_count: vec![0, 0, 0, 1],
1109 };
1110
1111 assert_eq!(
1112 graph.node_num_levels(0),
1113 4,
1114 "HnswGraph::node_level() must return the COUNT of levels, not the max index"
1115 );
1116 }
1117}