1use std::cmp::{max, min};
16use std::marker::PhantomData;
17
18use faiss::index::hnsw::Hnsw;
19use rand::Rng;
20use rand::distr::uniform::{UniformFloat, UniformSampler};
21use risingwave_pb::hummock::PbHnswGraph;
22use risingwave_pb::hummock::hnsw_graph::{PbHnswLevel, PbHnswNeighbor, PbHnswNode};
23
24use crate::hummock::HummockResult;
25use crate::vector::utils::{BoundedNearest, MinDistanceHeap};
26use crate::vector::{
27 MeasureDistance, MeasureDistanceBuilder, OnNearestItem, VectorDistance, VectorItem, VectorRef,
28};
29
/// Tuning knobs used when building an HNSW graph.
#[derive(Clone, Copy)]
pub struct HnswBuilderOptions {
    /// Neighbour capacity on levels above 0 (level 0 keeps `2 * m`, see `level_m`);
    /// also scales the random level distribution through `1 / ln(m)`.
    pub m: usize,
    /// Candidate-list size (`ef`) used while searching for a new node's neighbours.
    pub ef_construction: usize,
    /// Upper bound for the randomly drawn top level of a node.
    pub max_level: usize,
}
36
/// Neighbour capacity of a node on `level`: level 0 keeps twice as many
/// neighbours (`2 * m`) as any level above it (`m`).
pub fn level_m(m: usize, level: usize) -> usize {
    match level {
        0 => m * 2,
        _ => m,
    }
}
42
43impl HnswBuilderOptions {
44 fn m_l(&self) -> f32 {
45 1.0 / (self.m as f32).ln()
46 }
47}
48
49fn gen_level(options: &HnswBuilderOptions, rng: &mut impl Rng) -> usize {
50 let level = (-UniformFloat::<f32>::sample_single(0.0, 1.0, rng)
51 .unwrap()
52 .ln()
53 * options.m_l())
54 .floor() as usize;
55 min(level, options.max_level)
56}
57
58pub(crate) fn new_node(options: &HnswBuilderOptions, rng: &mut impl Rng) -> VectorHnswNode {
59 let level = gen_level(options, rng);
60 let mut level_neighbours = Vec::with_capacity(level + 1);
61 level_neighbours
62 .extend((0..=level).map(|level| BoundedNearest::new(level_m(options.m, level))));
63 VectorHnswNode { level_neighbours }
64}
65
66pub(crate) struct VectorHnswNode {
67 level_neighbours: Vec<BoundedNearest<usize>>,
68}
69
70impl VectorHnswNode {
71 fn num_levels(&self) -> usize {
73 self.level_neighbours.len()
74 }
75}
76
77struct InMemoryVectorStore {
78 dimension: usize,
79 vector_payload: Vec<VectorItem>,
80 info_payload: Vec<u8>,
81 info_offsets: Vec<usize>,
82}
83
84impl InMemoryVectorStore {
85 fn new(dimension: usize) -> Self {
86 Self {
87 dimension,
88 vector_payload: vec![],
89 info_payload: Default::default(),
90 info_offsets: vec![],
91 }
92 }
93
94 fn len(&self) -> usize {
95 self.info_offsets.len()
96 }
97
98 fn vec_ref(&self, idx: usize) -> VectorRef<'_> {
99 assert!(idx < self.info_offsets.len());
100 let start = idx * self.dimension;
101 let end = start + self.dimension;
102 VectorRef::from_slice_unchecked(&self.vector_payload[start..end])
103 }
104
105 fn info(&self, idx: usize) -> &[u8] {
106 let start = self.info_offsets[idx];
107 let end = if idx < self.info_offsets.len() - 1 {
108 self.info_offsets[idx + 1]
109 } else {
110 self.info_payload.len()
111 };
112 &self.info_payload[start..end]
113 }
114
115 fn add(&mut self, vec: VectorRef<'_>, info: &[u8]) {
116 assert_eq!(vec.dimension(), self.dimension);
117
118 self.vector_payload.extend_from_slice(vec.as_slice());
119 let offset = self.info_payload.len();
120 self.info_payload.extend_from_slice(info);
121 self.info_offsets.push(offset);
122 }
123}
124
/// Read access to one stored vector and its info blob.
pub trait VectorAccessor {
    /// The vector's items.
    fn vec_ref(&self) -> VectorRef<'_>;

    /// The opaque info bytes stored alongside the vector.
    fn info(&self) -> &[u8];
}

/// Source of vectors addressed by their graph index.
pub trait VectorStore: 'static {
    /// Accessor returned by `get_vector`; it may borrow from the store.
    type Accessor<'a>: VectorAccessor + 'a
    where
        Self: 'a;
    /// Fetches vector `idx`. Returns a `HummockResult`, so implementations may fail
    /// (presumably storage-backed ones — NOTE(review): confirm the failure modes).
    async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>>;
}
137
138pub struct InMemoryVectorStoreAccessor<'a> {
139 vector_store_impl: &'a InMemoryVectorStore,
140 idx: usize,
141}
142
143impl VectorAccessor for InMemoryVectorStoreAccessor<'_> {
144 fn vec_ref(&self) -> VectorRef<'_> {
145 self.vector_store_impl.vec_ref(self.idx)
146 }
147
148 fn info(&self) -> &[u8] {
149 self.vector_store_impl.info(self.idx)
150 }
151}
152
153impl VectorStore for InMemoryVectorStore {
154 type Accessor<'a> = InMemoryVectorStoreAccessor<'a>;
155
156 async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>> {
157 Ok(InMemoryVectorStoreAccessor {
158 vector_store_impl: self,
159 idx,
160 })
161 }
162}
163
/// Read-only view of an HNSW graph. Implemented in this file by `PbHnswGraph`
/// and by `HnswGraphBuilder`.
#[expect(clippy::len_without_is_empty)]
pub trait HnswGraph {
    /// Index of the node every search starts from.
    fn entrypoint(&self) -> usize;
    /// Number of nodes in the graph.
    fn len(&self) -> usize;
    /// Number of levels node `idx` occupies — a count (top level index + 1), not an index.
    fn node_num_levels(&self, idx: usize) -> usize;
    /// Neighbours of node `idx` on `level`, as `(neighbour index, distance)` pairs.
    fn node_neighbours(
        &self,
        idx: usize,
        level: usize,
    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_;
}
176
177impl HnswGraph for PbHnswGraph {
178 fn entrypoint(&self) -> usize {
179 self.entrypoint_id as _
180 }
181
182 fn len(&self) -> usize {
183 self.nodes.len()
184 }
185
186 fn node_num_levels(&self, idx: usize) -> usize {
187 self.nodes[idx].levels.len()
188 }
189
190 fn node_neighbours(
191 &self,
192 idx: usize,
193 level: usize,
194 ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
195 self.nodes[idx]
196 .levels
197 .get(level)
198 .into_iter()
199 .flat_map(|level| {
200 level
201 .neighbors
202 .iter()
203 .map(|neighbor| (neighbor.vector_id as usize, neighbor.distance))
204 })
205 }
206}
207
208pub struct HnswGraphBuilder {
209 entrypoint: usize,
211 nodes: Vec<VectorHnswNode>,
212}
213
214impl HnswGraphBuilder {
215 pub(crate) fn first(node: VectorHnswNode) -> Self {
216 Self {
217 entrypoint: 0,
218 nodes: vec![node],
219 }
220 }
221
222 pub fn to_protobuf(&self) -> PbHnswGraph {
223 let mut nodes = Vec::with_capacity(self.nodes.len());
224 for node in &self.nodes {
225 let mut levels = Vec::with_capacity(node.num_levels());
226 for level in &node.level_neighbours {
227 let mut neighbors = Vec::with_capacity(level.len());
228 for (distance, &neighbor_index) in level {
229 neighbors.push(PbHnswNeighbor {
230 vector_id: neighbor_index as u64,
231 distance,
232 });
233 }
234 levels.push(PbHnswLevel { neighbors });
235 }
236 nodes.push(PbHnswNode { levels });
237 }
238 PbHnswGraph {
239 entrypoint_id: self.entrypoint as u64,
240 nodes,
241 }
242 }
243
244 pub fn from_protobuf(pb: &PbHnswGraph, m: usize) -> Self {
245 let entrypoint = pb.entrypoint_id as usize;
246 let nodes = pb
247 .nodes
248 .iter()
249 .map(|node| {
250 let level_neighbours = node
251 .levels
252 .iter()
253 .enumerate()
254 .map(|(level_idx, level)| {
255 let level_m = level_m(m, level_idx);
256 let mut nearest = BoundedNearest::new(level_m);
257 for neighbor in &level.neighbors {
258 nearest.insert(neighbor.distance, || neighbor.vector_id as _);
259 }
260 nearest
261 })
262 .collect();
263 VectorHnswNode { level_neighbours }
264 })
265 .collect();
266 Self { entrypoint, nodes }
267 }
268}
269
270impl HnswGraph for HnswGraphBuilder {
271 fn entrypoint(&self) -> usize {
272 self.entrypoint
273 }
274
275 fn len(&self) -> usize {
276 self.nodes.len()
277 }
278
279 fn node_num_levels(&self, idx: usize) -> usize {
280 self.nodes[idx].num_levels()
281 }
282
283 fn node_neighbours(
284 &self,
285 idx: usize,
286 level: usize,
287 ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
288 (&self.nodes[idx].level_neighbours[level])
289 .into_iter()
290 .map(|(distance, &neighbour_index)| (neighbour_index, distance))
291 }
292}
293
/// Incremental HNSW index builder.
///
/// * `V` stores the inserted vectors,
/// * `G` is the graph,
/// * `M` provides the distance measure (used only through its associated functions),
/// * `R` is the randomness source for node levels.
pub struct HnswBuilder<V: VectorStore, G: HnswGraph, M: MeasureDistanceBuilder, R: Rng> {
    options: HnswBuilderOptions,

    vector_store: V,
    // `None` until the first vector is inserted.
    graph: Option<G>,

    rng: R,
    // `M` is never instantiated; only its associated functions are called.
    _measure: PhantomData<M>,
}
305
/// Counters gathered while searching or inserting into the graph.
#[derive(Debug, Default)]
pub struct HnswStats {
    /// Number of distance evaluations.
    distances_computed: usize,
    /// Number of candidates whose neighbour lists were expanded.
    nhops: usize,
}
311
/// Fixed-size "visited" set over indices `0..size`, one flag per index.
struct VecSet {
    payload: Vec<bool>,
}

impl VecSet {
    /// Creates a set for indices `0..size` with no index marked.
    fn new(size: usize) -> Self {
        let payload = vec![false; size];
        Self { payload }
    }

    /// Marks `idx`; panics when `idx` is out of range.
    fn set(&mut self, idx: usize) {
        let flag = &mut self.payload[idx];
        *flag = true;
    }

    /// Whether `idx` is marked; panics when `idx` is out of range.
    fn is_set(&self, idx: usize) -> bool {
        self.payload[idx]
    }

    /// Unmarks every index while keeping the size.
    fn reset(&mut self) {
        self.payload.fill(false);
    }
}
336
337impl<M: MeasureDistanceBuilder, R: Rng> HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, M, R> {
338 pub fn new(dimension: usize, rng: R, options: HnswBuilderOptions) -> Self {
339 Self {
340 options,
341 graph: None,
342 vector_store: InMemoryVectorStore::new(dimension),
343 rng,
344 _measure: Default::default(),
345 }
346 }
347
348 pub fn with_faiss_hnsw(self, faiss_hnsw: Hnsw<'_>) -> Self {
349 assert_eq!(self.vector_store.len(), faiss_hnsw.levels_raw().len());
350 let (entry_point, _max_level) = faiss_hnsw.entry_point().unwrap();
351 let levels = faiss_hnsw.levels_raw();
352 let Some(graph) = &self.graph else {
353 assert_eq!(levels.len(), 0);
354 return Self::new(self.vector_store.dimension, self.rng, self.options);
355 };
356 assert_eq!(levels.len(), graph.nodes.len());
357 let mut nodes = Vec::with_capacity(graph.nodes.len());
358 for (node, level_count) in levels.iter().enumerate() {
359 let level_count = *level_count as usize;
360 let mut level_neighbors = Vec::with_capacity(level_count);
361 for level_idx in 0..level_count {
362 let neighbors = faiss_hnsw.neighbors_raw(node, level_idx);
363 let mut nearest_neighbors = BoundedNearest::new(max(neighbors.len(), 1));
364 for &neighbor in neighbors {
365 nearest_neighbors.insert(
366 M::distance(
367 self.vector_store.vec_ref(node),
368 self.vector_store.vec_ref(neighbor as _),
369 ),
370 || neighbor as _,
371 );
372 }
373 level_neighbors.push(nearest_neighbors);
374 }
375 nodes.push(VectorHnswNode {
376 level_neighbours: level_neighbors,
377 });
378 }
379 Self {
380 options: self.options,
381 graph: Some(HnswGraphBuilder {
382 entrypoint: entry_point,
383 nodes,
384 }),
385 vector_store: self.vector_store,
386 rng: self.rng,
387 _measure: Default::default(),
388 }
389 }
390
391 pub fn print_graph(&self) {
392 let Some(graph) = &self.graph else {
393 println!("empty graph");
394 return;
395 };
396 println!(
397 "entrypoint {} has {} levels",
398 graph.entrypoint,
399 graph.nodes[graph.entrypoint].num_levels()
400 );
401 for (i, node) in graph.nodes.iter().enumerate() {
402 println!("node {} has {} levels", i, node.num_levels());
403 for level in 0..node.num_levels() {
404 print!("level {}: ", level);
405 for (_, &neighbor) in &node.level_neighbours[level] {
406 print!("{} ", neighbor);
407 }
408 println!()
409 }
410 }
411 }
412
413 pub async fn insert(&mut self, vec: VectorRef<'_>, info: &[u8]) -> HummockResult<HnswStats> {
414 let node = new_node(&self.options, &mut self.rng);
415 let stat = if let Some(graph) = &mut self.graph {
416 insert_graph::<M>(
417 &self.vector_store,
418 graph,
419 node,
420 vec,
421 self.options.ef_construction,
422 )
423 .await?
424 } else {
425 self.graph = Some(HnswGraphBuilder::first(node));
426 HnswStats::default()
427 };
428 self.vector_store.add(vec, info);
429 Ok(stat)
430 }
431}
432
433pub(crate) async fn insert_graph<M: MeasureDistanceBuilder>(
434 vector_store: &impl VectorStore,
435 graph: &mut HnswGraphBuilder,
436 mut node: VectorHnswNode,
437 vec: VectorRef<'_>,
438 ef_construction: usize,
439) -> HummockResult<HnswStats> {
440 {
441 let mut stats = HnswStats::default();
442 let mut visited = VecSet::new(graph.nodes.len());
443 let entrypoint_index = graph.entrypoint();
444 let measure = M::new(vec);
445 let mut entrypoints = BoundedNearest::new(1);
446
447 entrypoints.insert(
448 measure.measure(vector_store.get_vector(entrypoint_index).await?.vec_ref()),
449 || (entrypoint_index, ()),
450 );
451 stats.distances_computed += 1;
452
453 let entry_top_level_idx = graph.nodes[entrypoint_index].num_levels() - 1;
455 let node_top_level_idx = node.num_levels() - 1;
456
457 for level_idx in ((node_top_level_idx + 1)..=entry_top_level_idx).rev() {
458 entrypoints = search_layer(
459 vector_store,
460 &*graph,
461 &measure,
462 |_, _, _| (),
463 entrypoints,
464 level_idx,
465 1,
466 &mut stats,
467 &mut visited,
468 )
469 .await?;
470 }
471
472 let start_level_idx = min(entry_top_level_idx, node_top_level_idx);
474 for level_idx in (0..=start_level_idx).rev() {
475 entrypoints = search_layer(
476 vector_store,
477 &*graph,
478 &measure,
479 |_, _, _| (),
480 entrypoints,
481 level_idx,
482 ef_construction,
483 &mut stats,
484 &mut visited,
485 )
486 .await?;
487 let level_neighbour = &mut node.level_neighbours[level_idx];
488 for (neighbour_distance, &(neighbour_index, _)) in &entrypoints {
489 level_neighbour.insert(neighbour_distance, || neighbour_index);
490 }
491 }
492
493 let vector_index = graph.nodes.len();
494 for (level_index, level) in node.level_neighbours.iter().enumerate() {
495 for (neighbour_distance, &neighbour_index) in level {
496 graph.nodes[neighbour_index].level_neighbours[level_index]
497 .insert(neighbour_distance, || vector_index);
498 }
499 }
500 if graph.nodes[entrypoint_index].num_levels() < node.num_levels() {
501 graph.entrypoint = vector_index;
502 }
503 graph.nodes.push(node);
504 Ok(stats)
505 }
506}
507
508pub async fn nearest<O: Send, M: MeasureDistanceBuilder>(
509 vector_store: &impl VectorStore,
510 graph: &impl HnswGraph,
511 vec: VectorRef<'_>,
512 on_nearest_fn: impl OnNearestItem<O>,
513 ef_search: usize,
514 top_n: usize,
515) -> HummockResult<(Vec<O>, HnswStats)> {
516 {
517 let mut stats = HnswStats::default();
520 if ef_search == 0 || top_n == 0 {
521 return Ok((Vec::new(), stats));
522 }
523
524 let entrypoint_index = graph.entrypoint();
525 let measure = M::new(vec);
526 let mut entrypoints = BoundedNearest::new(1);
527 let entrypoint_vector = vector_store.get_vector(entrypoint_index).await?;
528 let entrypoint_distance = measure.measure(entrypoint_vector.vec_ref());
529 entrypoints.insert(entrypoint_distance, || {
530 (
531 entrypoint_index,
532 on_nearest_fn(
533 entrypoint_vector.vec_ref(),
534 entrypoint_distance,
535 entrypoint_vector.info(),
536 ),
537 )
538 });
539 stats.distances_computed += 1;
540 let entry_top_level_idx = graph.node_num_levels(entrypoint_index) - 1;
541 let mut visited = VecSet::new(graph.len());
542 for level_idx in (1..=entry_top_level_idx).rev() {
543 entrypoints = search_layer(
544 vector_store,
545 graph,
546 &measure,
547 &on_nearest_fn,
548 entrypoints,
549 level_idx, 1,
551 &mut stats,
552 &mut visited,
553 )
554 .await?;
555 }
556 entrypoints = search_layer(
557 vector_store,
558 graph,
559 &measure,
560 &on_nearest_fn,
561 entrypoints,
562 0,
563 ef_search,
564 &mut stats,
565 &mut visited,
566 )
567 .await?;
568 Ok((
569 entrypoints.collect_with(|_, (_, output)| output, Some(top_n)),
570 stats,
571 ))
572 }
573}
574
575async fn search_layer<O: Send>(
576 vector_store: &impl VectorStore,
577 graph: &impl HnswGraph,
578 measure: &impl MeasureDistance,
579 on_nearest_fn: impl OnNearestItem<O>,
580 entrypoints: BoundedNearest<(usize, O)>,
581 level_index: usize,
582 ef: usize,
583 stats: &mut HnswStats,
584 visited: &mut VecSet,
585) -> HummockResult<BoundedNearest<(usize, O)>> {
586 #[cfg(test)]
587 {
588 __hnsw_test_hooks::record_level(level_index);
589 }
590 {
591 visited.reset();
592
593 let mut candidates = MinDistanceHeap::with_capacity(ef);
594 for (distance, &(idx, _)) in &entrypoints {
595 visited.set(idx);
596 candidates.push(distance, idx);
597 }
598 let mut nearest = entrypoints;
599 nearest.resize(ef);
600
601 while let Some((c_distance, c_index)) = candidates.pop() {
602 let (f_distance, _) = nearest.furthest().expect("non-empty");
603 if c_distance > f_distance {
604 break;
607 }
608 stats.nhops += 1;
609 for (neighbour_index, _) in graph.node_neighbours(c_index, level_index) {
610 if visited.is_set(neighbour_index) {
611 continue;
612 }
613 visited.set(neighbour_index);
614 let vector = vector_store.get_vector(neighbour_index).await?;
615 let info = vector.info();
616
617 let distance = measure.measure(vector.vec_ref());
618
619 stats.distances_computed += 1;
620
621 let mut added = false;
622 let added = &mut added;
623 nearest.insert(distance, || {
624 *added = true;
625 (
626 neighbour_index,
627 on_nearest_fn(vector.vec_ref(), distance, info),
628 )
629 });
630 if *added {
631 candidates.push(distance, neighbour_index);
632 }
633 }
634 }
635
636 Ok(nearest)
637 }
638}
639
/// Test-only instrumentation: records, per thread, the level passed to every
/// `search_layer` call so tests can check the traversal order.
#[cfg(test)]
mod __hnsw_test_hooks {
    use std::cell::RefCell;

    thread_local! {
        static LEVELS: RefCell<Vec<usize>> = const { RefCell::new(Vec::new()) };
    }

    /// Appends `level` to this thread's log.
    pub fn record_level(level: usize) {
        LEVELS.with_borrow_mut(|levels| levels.push(level));
    }

    /// Returns this thread's log, leaving it empty.
    pub fn take_levels() -> Vec<usize> {
        LEVELS.take()
    }

    /// Empties this thread's log.
    pub fn clear_levels() {
        LEVELS.with_borrow_mut(Vec::clear);
    }
}
661
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::iter::repeat_with;
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use faiss::{ConcurrentIndex, Index, MetricType};
    use futures::executor::block_on;
    use itertools::Itertools;
    use rand::SeedableRng;
    use rand::prelude::StdRng;
    use risingwave_common::types::F32;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_common::vector::distance::InnerProductDistance;

    use super::*;
    use crate::vector::NearestBuilder;
    use crate::vector::hnsw::{HnswBuilder, HnswBuilderOptions, nearest};
    use crate::vector::test_utils::{gen_info, gen_vector};

    /// Fraction of `expected` that also appears in `actual`.
    fn recall(actual: &[Bytes], expected: &[Bytes]) -> f32 {
        let expected: HashSet<_> = expected.iter().map(|b| b.as_ref()).collect();
        (actual
            .iter()
            .filter(|info| expected.contains(info.as_ref()))
            .count() as f32)
            / (expected.len() as f32)
    }

    /// Plain squared-L2 distance, used by the traversal tests.
    struct TestL2;

    struct TestL2Measure<'a> {
        q: VectorRef<'a>,
    }

    impl MeasureDistanceBuilder for TestL2 {
        type Measure<'a> = TestL2Measure<'a>;

        fn new<'a>(q: VectorRef<'a>) -> Self::Measure<'a> {
            TestL2Measure { q }
        }

        fn distance(a: VectorRef<'_>, b: VectorRef<'_>) -> VectorDistance {
            a.as_slice()
                .iter()
                .zip_eq_debug(b.as_slice().iter())
                .map(|(&x, &y)| {
                    let xf: f32 = x.into();
                    let yf: f32 = y.into();
                    let d = (xf as f64) - (yf as f64);
                    d * d
                })
                .sum::<f64>()
        }
    }

    impl MeasureDistance for TestL2Measure<'_> {
        fn measure(&self, v: VectorRef<'_>) -> VectorDistance {
            TestL2::distance(self.q, v)
        }
    }

    fn opts(m: usize, efc: usize, max_level: usize) -> HnswBuilderOptions {
        HnswBuilderOptions {
            m,
            ef_construction: efc,
            max_level,
        }
    }

    const VERBOSE: bool = false;
    const VECTOR_LEN: usize = 128;
    const INPUT_COUNT: usize = 20000;
    const QUERY_COUNT: usize = 5000;
    const TOP_N: usize = 10;
    const EF_SEARCH_LIST: &[usize] = &[16];

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_hnsw_basic() {
        let input = (0..INPUT_COUNT)
            .map(|i| (gen_vector(VECTOR_LEN), gen_info(i)))
            .collect_vec();
        let m = 40;
        let hnsw_start_time = Instant::now();
        let mut hnsw_builder = HnswBuilder::<_, _, InnerProductDistance, _>::new(
            VECTOR_LEN,
            StdRng::seed_from_u64(233),
            HnswBuilderOptions {
                m,
                ef_construction: 40,
                max_level: 10,
            },
        );
        for (vec, info) in &input {
            hnsw_builder.insert(vec.to_ref(), info).await.unwrap();
        }
        println!("hnsw build time: {:?}", hnsw_start_time.elapsed());
        if VERBOSE {
            hnsw_builder.print_graph();
        }

        let faiss_hnsw_start_time = Instant::now();
        let mut faiss_hnsw = faiss::index::hnsw::HnswFlatIndex::new(
            VECTOR_LEN as _,
            m as _,
            MetricType::InnerProduct,
        )
        .unwrap();

        faiss_hnsw
            .add(F32::inner_slice(&hnsw_builder.vector_store.vector_payload))
            .unwrap();
        if VERBOSE {
            let faiss_hnsw = faiss_hnsw.hnsw();
            let (entry_point, max_level) = faiss_hnsw.entry_point().unwrap();
            println!("faiss hnsw entry_point: {} {}", entry_point, max_level);
            let levels = faiss_hnsw.levels_raw();
            println!("entry point level: {}", levels[entry_point]);
            for level in 0..=max_level {
                let neighbors = faiss_hnsw.neighbors_raw(entry_point, level);
                println!("entry point level {} neighbors {:?}", level, neighbors);
            }
        }
        println!(
            "faiss hnsw build time: {:?}",
            faiss_hnsw_start_time.elapsed()
        );

        let queries = (0..QUERY_COUNT)
            .map(|_| gen_vector(VECTOR_LEN))
            .collect_vec();
        // Exact top-N by brute force, used as ground truth for recall.
        let expected = queries
            .iter()
            .map(|query| {
                let mut nearest_builder =
                    NearestBuilder::<'_, _, InnerProductDistance>::new(query.to_ref(), TOP_N);
                nearest_builder.add(
                    input
                        .iter()
                        .map(|(vec, info)| (vec.to_ref(), info.as_ref())),
                    |_, _, info| Bytes::copy_from_slice(info),
                );
                nearest_builder.finish()
            })
            .collect_vec();
        let faiss_start_time = Instant::now();
        let repeat_query = if cfg!(debug_assertions) { 1 } else { 60 };
        println!("start faiss query");
        let faiss_actual = repeat_with(|| queries.iter().enumerate())
            .take(repeat_query)
            .flatten()
            .map(|(i, query)| {
                let start_time = Instant::now();
                let actual = faiss_hnsw
                    .assign(query.as_raw_slice(), TOP_N)
                    .unwrap()
                    .labels
                    .into_iter()
                    .filter_map(|i| i.get().map(|i| gen_info(i as _)))
                    .collect_vec();
                let recall = recall(&actual, &expected[i]);
                (start_time.elapsed(), recall)
            })
            .collect_vec();
        let faiss_query_time = faiss_start_time.elapsed();
        println!("start query");
        let actuals = EF_SEARCH_LIST
            .iter()
            .map(|&ef_search| {
                let start_time = Instant::now();
                let actuals = repeat_with(|| queries.iter().enumerate())
                    .take(repeat_query)
                    .flatten()
                    .map(|(i, query)| {
                        let start_time = Instant::now();
                        let (actual, stats) = block_on(nearest::<_, InnerProductDistance>(
                            &hnsw_builder.vector_store,
                            hnsw_builder.graph.as_ref().unwrap(),
                            query.to_ref(),
                            |_, _, info| Bytes::copy_from_slice(info),
                            ef_search,
                            TOP_N,
                        ))
                        .unwrap();
                        if VERBOSE {
                            println!("stats: {:?}", stats);
                        }
                        let recall = recall(&actual, &expected[i]);
                        (start_time.elapsed(), recall)
                    })
                    .collect_vec();
                (actuals, start_time.elapsed())
            })
            .collect_vec();
        if VERBOSE {
            for i in 0..20 {
                for elapsed in [&faiss_actual]
                    .into_iter()
                    .chain(actuals.iter().map(|(actual, _)| actual))
                    .map(|actual| actual[i].0)
                {
                    print!("{:?}\t", elapsed);
                }
                println!();
                for recall in [&faiss_actual]
                    .into_iter()
                    .chain(actuals.iter().map(|(actual, _)| actual))
                    .map(|actual| actual[i].1)
                {
                    print!("{}\t", recall);
                }
                println!();
            }
        }
        /// Mean recall over `(elapsed, recall)` samples.
        fn avg_recall(samples: &[(Duration, f32)]) -> f32 {
            samples.iter().map(|(_, recall)| *recall).sum::<f32>() / (samples.len() as f32)
        }
        println!("faiss {:?} {}", faiss_query_time, avg_recall(&faiss_actual));
        for (ef_search, (samples, elapsed)) in EF_SEARCH_LIST.iter().zip(&actuals) {
            println!(
                "ef_search[{}] {:?} {}",
                ef_search,
                elapsed,
                avg_recall(samples)
            );
        }
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn hnsw_insert_graph_visits_expected_upper_layers() -> HummockResult<()> {
        use super::__hnsw_test_hooks as hooks;

        let dim = 8;
        let options = opts(8, 16, 8);
        for seed in 1u64..=200 {
            let mut hnsw: HnswBuilder<InMemoryVectorStore, HnswGraphBuilder, TestL2, StdRng> =
                HnswBuilder::new(dim, StdRng::seed_from_u64(seed), options);

            let v0 = gen_vector(dim);
            let _ = hnsw
                .insert(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0))
                .await?;

            let g = hnsw.graph.as_ref().unwrap();
            let entry_idx = g.entrypoint();
            let entry_top_level_idx = g.node_num_levels(entry_idx) - 1;

            // Need a tall enough entry point to observe a multi-level greedy descent.
            if entry_top_level_idx < 2 {
                continue;
            }

            hooks::clear_levels();
            let v1 = gen_vector(dim);
            let _ = hnsw
                .insert(VectorRef::from_slice_unchecked(v1.as_slice()), &gen_info(1))
                .await?;

            let g = hnsw.graph.as_ref().unwrap();
            let new_idx = g.len() - 1;
            let node_top_level_idx = g.node_num_levels(new_idx) - 1;

            if node_top_level_idx > entry_top_level_idx {
                continue;
            }

            // `insert_graph` first descends greedily through the levels above the new node's
            // top level, then searches every level the node occupies, down to 0. Comparing
            // the full sequence avoids mistaking phase-2 levels >= 1 for greedy levels.
            let expected: Vec<usize> = ((node_top_level_idx + 1)..=entry_top_level_idx)
                .rev()
                .chain((0..=node_top_level_idx).rev())
                .collect();
            let visited = hooks::take_levels();

            assert_eq!(
                visited, expected,
                "seed={seed}, entry_top_level_idx={entry_top_level_idx}, node_top_level_idx={node_top_level_idx}"
            );
            return Ok(());
        }

        panic!(
            "could not find a suitable seed (entry_top_level_idx>=2 and node_top<=entry_top_level_idx) within the search window"
        );
    }

    #[cfg(not(madsim))]
    #[tokio::test]
    async fn hnsw_nearest_visits_expected_upper_layers() -> HummockResult<()> {
        use super::__hnsw_test_hooks as hooks;

        let dim = 1;
        let mut store = InMemoryVectorStore::new(dim);
        let v0 = gen_vector(dim);
        store.add(VectorRef::from_slice_unchecked(v0.as_slice()), &gen_info(0));

        // A single node occupying levels 0..=2 with no neighbours.
        let graph = HnswGraphBuilder {
            entrypoint: 0,
            nodes: vec![VectorHnswNode {
                level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
            }],
        };

        let q = gen_vector(dim);

        hooks::clear_levels();

        let (_out, _stats) = nearest::<usize, TestL2>(
            &store,
            &graph,
            VectorRef::from_slice_unchecked(q.as_slice()),
            |_v, _d, _info| 0usize,
            4,
            1,
        )
        .await?;

        let visited = hooks::take_levels();

        let upper: Vec<usize> = visited.into_iter().filter(|&l| l >= 1).collect();

        assert_eq!(
            upper,
            vec![2, 1],
            "nearest should visit levels [2, 1] top-down before level 0"
        );
        Ok(())
    }

    #[cfg(not(madsim))]
    #[test]
    fn hnsw_vector_hnsw_node_level_returns_count() {
        let node = VectorHnswNode {
            level_neighbours: (0..3).map(|_| BoundedNearest::new(0)).collect(),
        };

        assert_eq!(
            node.num_levels(),
            node.level_neighbours.len(),
            "VectorHnswNode::num_levels() must return the count of levels, \
             not the maximum level index"
        );
    }

    #[cfg(not(madsim))]
    #[test]
    fn hnsw_graph_node_level_returns_count() {
        let graph = HnswGraphBuilder {
            entrypoint: 0,
            nodes: vec![VectorHnswNode {
                level_neighbours: (0..4).map(|_| BoundedNearest::new(0)).collect(),
            }],
        };

        assert_eq!(
            graph.node_num_levels(0),
            4,
            "HnswGraph::node_num_levels() must return the COUNT of levels, not the max index"
        );
    }
}