1use std::cmp::min;
16use std::marker::PhantomData;
17
18use faiss::index::hnsw::Hnsw;
19use rand::Rng;
20use rand::distr::uniform::{UniformFloat, UniformSampler};
21
22use crate::hummock::HummockResult;
23use crate::vector::utils::{BoundedNearest, MinDistanceHeap};
24use crate::vector::{
25 MeasureDistance, MeasureDistanceBuilder, OnNearestItem, VectorDistance, VectorInner,
26 VectorItem, VectorRef,
27};
28
/// Tuning knobs used while constructing an HNSW graph.
pub struct HnswBuilderOptions {
    /// Base neighbour budget of a node on one level (see `level_m` for the
    /// per-level value actually applied).
    pub m: usize,
    /// Size of the candidate list used by the layer search while a new node
    /// is being linked into the graph.
    pub ef_construction: usize,
    /// Upper bound applied to the randomly drawn level of a new node.
    pub max_level: usize,
}

impl HnswBuilderOptions {
    /// Neighbour capacity on `level`: `2 * m` on the bottom level, `m` on
    /// every level above it.
    fn level_m(&self, level: usize) -> usize {
        match level {
            0 => 2 * self.m,
            _ => self.m,
        }
    }

    /// Multiplier used when drawing a node level: `1 / ln(m)`.
    fn m_l(&self) -> f32 {
        let ln_m = (self.m as f32).ln();
        1.0 / ln_m
    }
}
46
47fn gen_level(options: &HnswBuilderOptions, rng: &mut impl Rng) -> usize {
48 let level = (-UniformFloat::<f32>::sample_single(0.0, 1.0, rng)
49 .unwrap()
50 .ln()
51 * options.m_l())
52 .floor() as usize;
53 min(level, options.max_level)
54}
55
56pub(crate) fn new_node(options: &HnswBuilderOptions, rng: &mut impl Rng) -> VectorHnswNode {
57 let level = gen_level(options, rng);
58 let mut level_neighbours = Vec::with_capacity(level);
59 level_neighbours.extend((0..=level).map(|level| BoundedNearest::new(options.level_m(level))));
60 VectorHnswNode { level_neighbours }
61}
62
63pub(crate) struct VectorHnswNode {
64 level_neighbours: Vec<BoundedNearest<usize>>,
65}
66
67impl VectorHnswNode {
68 fn level(&self) -> usize {
69 self.level_neighbours.len()
70 }
71}
72
73struct VectorStoreImpl {
74 vector_len: usize,
75 vector_payload: Vec<VectorItem>,
76 info_payload: Vec<u8>,
77 info_offsets: Vec<usize>,
78}
79
80impl VectorStoreImpl {
81 fn new(vector_len: usize) -> Self {
82 Self {
83 vector_len,
84 vector_payload: vec![],
85 info_payload: Default::default(),
86 info_offsets: vec![],
87 }
88 }
89
90 fn len(&self) -> usize {
91 self.info_offsets.len()
92 }
93
94 fn vec_ref(&self, idx: usize) -> VectorRef<'_> {
95 assert!(idx < self.info_offsets.len());
96 let start = idx * self.vector_len;
97 let end = start + self.vector_len;
98 VectorInner(&self.vector_payload[start..end])
99 }
100
101 fn info(&self, idx: usize) -> &[u8] {
102 let start = self.info_offsets[idx];
103 let end = if idx < self.info_offsets.len() - 1 {
104 self.info_offsets[idx + 1]
105 } else {
106 self.info_payload.len()
107 };
108 &self.info_payload[start..end]
109 }
110
111 fn add(&mut self, vec: VectorRef<'_>, info: &[u8]) {
112 assert_eq!(vec.0.len(), self.vector_len);
113
114 self.vector_payload.extend_from_slice(vec.0);
115 let offset = self.info_payload.len();
116 self.info_payload.extend_from_slice(info);
117 self.info_offsets.push(offset);
118 }
119}
120
/// Read access to a single stored vector and the info bytes attached to it.
pub trait VectorAccessor {
    /// Borrows the vector data.
    fn vec_ref(&self) -> VectorRef<'_>;

    /// Borrows the info bytes stored alongside the vector.
    fn info(&self) -> &[u8];
}
126
/// Source of vectors addressed by index, as used by the graph insert and
/// search routines in this file.
pub trait VectorStore: 'static {
    /// Handle returned by `get_vector`; it may borrow from the store.
    type Accessor<'a>: VectorAccessor + 'a
    where
        Self: 'a;
    /// Looks up the vector at `idx`, returning an accessor for it or an error.
    async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>>;
}
133
134pub struct VectorStoreImplAccessor<'a> {
135 vector_store_impl: &'a VectorStoreImpl,
136 idx: usize,
137}
138
139impl VectorAccessor for VectorStoreImplAccessor<'_> {
140 fn vec_ref(&self) -> VectorRef<'_> {
141 self.vector_store_impl.vec_ref(self.idx)
142 }
143
144 fn info(&self) -> &[u8] {
145 self.vector_store_impl.info(self.idx)
146 }
147}
148
149impl VectorStore for VectorStoreImpl {
150 type Accessor<'a> = VectorStoreImplAccessor<'a>;
151
152 async fn get_vector(&self, idx: usize) -> HummockResult<Self::Accessor<'_>> {
153 Ok(VectorStoreImplAccessor {
154 vector_store_impl: self,
155 idx,
156 })
157 }
158}
159
/// Read-only view of an HNSW graph, as needed by the search routines.
#[expect(clippy::len_without_is_empty)]
pub trait HnswGraph {
    /// Index of the node every search starts from.
    fn entrypoint(&self) -> usize;
    /// Number of nodes in the graph.
    fn len(&self) -> usize;
    /// Level value of node `idx`; the callers in this file treat it as the
    /// number of levels the node lives on.
    fn node_level(&self, idx: usize) -> usize;
    /// Iterates the neighbours of node `idx` on `level` as
    /// `(neighbour index, distance)` pairs.
    fn node_neighbours(
        &self,
        idx: usize,
        level: usize,
    ) -> impl Iterator<Item = (usize, VectorDistance)> + '_;
}
171
172pub struct HnswGraphBuilder {
173 entrypoint: usize,
175 nodes: Vec<VectorHnswNode>,
176}
177
178impl HnswGraphBuilder {
179 pub(crate) fn first(node: VectorHnswNode) -> Self {
180 Self {
181 entrypoint: 0,
182 nodes: vec![node],
183 }
184 }
185}
186
187impl HnswGraph for HnswGraphBuilder {
188 fn entrypoint(&self) -> usize {
189 self.entrypoint
190 }
191
192 fn len(&self) -> usize {
193 self.nodes.len()
194 }
195
196 fn node_level(&self, idx: usize) -> usize {
197 self.nodes[idx].level()
198 }
199
200 fn node_neighbours(
201 &self,
202 idx: usize,
203 level: usize,
204 ) -> impl Iterator<Item = (usize, VectorDistance)> + '_ {
205 (&self.nodes[idx].level_neighbours[level])
206 .into_iter()
207 .map(|(distance, &neighbour_index)| (neighbour_index, distance))
208 }
209}
210
/// Builder that owns a vector store and the HNSW graph constructed over it.
///
/// `M` selects the distance measure; it is only used at the type level.
pub struct HnswBuilder<V: VectorStore, G: HnswGraph, M: MeasureDistanceBuilder, R: Rng> {
    // Construction parameters.
    options: HnswBuilderOptions,

    // Storage for the inserted vectors and their info bytes.
    vector_store: V,
    // `None` until the first vector has been inserted.
    graph: Option<G>,

    // Random source used to draw the level of each new node.
    rng: R,
    // Marker tying the builder to its distance measure.
    _measure: PhantomData<M>,
}
222
/// Counters collected during one insert or search.
#[derive(Default, Debug)]
pub struct HnswStats {
    // Number of distance evaluations performed.
    distances_computed: usize,
    // Number of candidates whose neighbour lists were expanded.
    nhops: usize,
}
228
/// Fixed-size set of indices backed by one `bool` per slot.
struct VecSet {
    /// `payload[i]` is `true` when index `i` is in the set.
    payload: Vec<bool>,
}

impl VecSet {
    /// Creates an empty set able to hold indices `0..size`.
    fn new(size: usize) -> Self {
        let payload = vec![false; size];
        Self { payload }
    }

    /// Marks `idx` as present; panics when `idx` is out of range.
    fn set(&mut self, idx: usize) {
        let slot = &mut self.payload[idx];
        *slot = true;
    }

    /// Reports whether `idx` is present; panics when `idx` is out of range.
    fn is_set(&self, idx: usize) -> bool {
        self.payload[idx]
    }

    /// Clears every index while keeping the size.
    fn reset(&mut self) {
        self.payload.iter_mut().for_each(|flag| *flag = false);
    }
}
253
254impl<M: MeasureDistanceBuilder, R: Rng> HnswBuilder<VectorStoreImpl, HnswGraphBuilder, M, R> {
255 pub fn new(vector_len: usize, rng: R, options: HnswBuilderOptions) -> Self {
256 Self {
257 options,
258 graph: None,
259 vector_store: VectorStoreImpl::new(vector_len),
260 rng,
261 _measure: Default::default(),
262 }
263 }
264
265 pub fn with_faiss_hnsw(self, faiss_hnsw: Hnsw<'_>) -> Self {
266 assert_eq!(self.vector_store.len(), faiss_hnsw.levels_raw().len());
267 let (entry_point, _max_level) = faiss_hnsw.entry_point().unwrap();
268 let levels = faiss_hnsw.levels_raw();
269 let Some(graph) = &self.graph else {
270 assert_eq!(levels.len(), 0);
271 return Self::new(self.vector_store.vector_len, self.rng, self.options);
272 };
273 assert_eq!(levels.len(), graph.nodes.len());
274 let mut nodes = Vec::with_capacity(graph.nodes.len());
275 for (node, level_count) in levels.iter().enumerate() {
276 let level_count = *level_count as usize;
277 let mut level_neighbors = Vec::with_capacity(level_count);
278 for level in 0..level_count {
279 let neighbors = faiss_hnsw.neighbors_raw(node, level);
280 let mut nearest_neighbors = BoundedNearest::new(neighbors.len());
281 for &neighbor in neighbors {
282 nearest_neighbors.insert(
283 M::distance(
284 self.vector_store.vec_ref(node),
285 self.vector_store.vec_ref(neighbor as _),
286 ),
287 || neighbor as _,
288 );
289 }
290 level_neighbors.push(nearest_neighbors);
291 }
292 nodes.push(VectorHnswNode {
293 level_neighbours: level_neighbors,
294 });
295 }
296 Self {
297 options: self.options,
298 graph: Some(HnswGraphBuilder {
299 entrypoint: entry_point,
300 nodes,
301 }),
302 vector_store: self.vector_store,
303 rng: self.rng,
304 _measure: Default::default(),
305 }
306 }
307
308 pub fn print_graph(&self) {
309 let Some(graph) = &self.graph else {
310 println!("empty graph");
311 return;
312 };
313 println!(
314 "entrypoint {} in level {}",
315 graph.entrypoint,
316 graph.nodes[graph.entrypoint].level()
317 );
318 for (i, node) in graph.nodes.iter().enumerate() {
319 println!("node {} has {} levels", i, node.level());
320 for level in 0..node.level() {
321 print!("level {}: ", level);
322 for (_, &neighbor) in &node.level_neighbours[level] {
323 print!("{} ", neighbor);
324 }
325 println!()
326 }
327 }
328 }
329
330 pub async fn insert(&mut self, vec: VectorRef<'_>, info: &[u8]) -> HummockResult<HnswStats> {
331 let node = new_node(&self.options, &mut self.rng);
332 let stat = if let Some(graph) = &mut self.graph {
333 insert_graph::<M>(
334 &self.vector_store,
335 graph,
336 node,
337 vec,
338 self.options.ef_construction,
339 )
340 .await?
341 } else {
342 self.graph = Some(HnswGraphBuilder::first(node));
343 HnswStats::default()
344 };
345 self.vector_store.add(vec, info);
346 Ok(stat)
347 }
348}
349
350pub(crate) async fn insert_graph<M: MeasureDistanceBuilder>(
351 vector_store: &impl VectorStore,
352 graph: &mut HnswGraphBuilder,
353 mut node: VectorHnswNode,
354 vec: VectorRef<'_>,
355 ef_construction: usize,
356) -> HummockResult<HnswStats> {
357 {
358 let mut stats = HnswStats::default();
359 let entrypoint_index = graph.entrypoint();
360 let measure = M::new(vec);
361 let mut entrypoints = BoundedNearest::new(1);
362 entrypoints.insert(
363 measure.measure(vector_store.get_vector(entrypoint_index).await?.vec_ref()),
364 || (entrypoint_index, ()),
365 );
366 let mut visited = VecSet::new(graph.nodes.len());
367 let entrypoint_level = graph.nodes[entrypoint_index].level();
368 {
369 let mut curr_level = entrypoint_level;
370 while curr_level > node.level() + 1 {
371 curr_level -= 1;
372 entrypoints = search_layer(
373 vector_store,
374 &*graph,
375 &measure,
376 |_, _, _| (),
377 entrypoints,
378 curr_level,
379 1,
380 &mut stats,
381 &mut visited,
382 )
383 .await?;
384 }
385 }
386 {
387 let mut curr_level = min(entrypoint_level, node.level());
388 while curr_level > 0 {
389 curr_level -= 1;
390 entrypoints = search_layer(
391 vector_store,
392 &*graph,
393 &measure,
394 |_, _, _| (),
395 entrypoints,
396 curr_level,
397 ef_construction,
398 &mut stats,
399 &mut visited,
400 )
401 .await?;
402 let level_neighbour = &mut node.level_neighbours[curr_level];
403 for (neighbour_distance, &(neighbour_index, _)) in &entrypoints {
404 level_neighbour.insert(neighbour_distance, || neighbour_index);
405 }
406 }
407 }
408 let vector_index = graph.nodes.len();
409 for (level_index, level) in node.level_neighbours.iter().enumerate() {
410 for (neighbour_distance, &neighbour_index) in level {
411 graph.nodes[neighbour_index].level_neighbours[level_index]
412 .insert(neighbour_distance, || vector_index);
413 }
414 }
415 if graph.nodes[entrypoint_index].level() < node.level() {
416 graph.entrypoint = vector_index;
417 }
418 graph.nodes.push(node);
419 Ok(stats)
420 }
421}
422
423pub async fn nearest<O: Send, M: MeasureDistanceBuilder>(
424 vector_store: &impl VectorStore,
425 graph: &impl HnswGraph,
426 vec: VectorRef<'_>,
427 on_nearest_fn: impl OnNearestItem<O>,
428 ef_search: usize,
429 top_n: usize,
430) -> HummockResult<(Vec<O>, HnswStats)> {
431 {
432 let entrypoint_index = graph.entrypoint();
433 let measure = M::new(vec);
434 let mut entrypoints = BoundedNearest::new(1);
435 let mut stats = HnswStats::default();
436 let entrypoint_vector = vector_store.get_vector(entrypoint_index).await?;
437 let entrypoint_distance = measure.measure(entrypoint_vector.vec_ref());
438 entrypoints.insert(entrypoint_distance, || {
439 (
440 entrypoint_index,
441 on_nearest_fn(
442 entrypoint_vector.vec_ref(),
443 entrypoint_distance,
444 entrypoint_vector.info(),
445 ),
446 )
447 });
448 stats.distances_computed += 1;
449 let entrypoint_level = graph.node_level(entrypoint_index);
450 let mut visited = VecSet::new(graph.len());
451 {
452 let mut curr_level = entrypoint_level;
453 while curr_level > 1 {
454 curr_level -= 1;
455 entrypoints = search_layer(
456 vector_store,
457 graph,
458 &measure,
459 &on_nearest_fn,
460 entrypoints,
461 curr_level,
462 1,
463 &mut stats,
464 &mut visited,
465 )
466 .await?;
467 }
468 }
469 entrypoints = search_layer(
470 vector_store,
471 graph,
472 &measure,
473 &on_nearest_fn,
474 entrypoints,
475 0,
476 ef_search,
477 &mut stats,
478 &mut visited,
479 )
480 .await?;
481 Ok((
482 entrypoints.collect_with(|(_, output)| output, Some(top_n)),
483 stats,
484 ))
485 }
486}
487
488async fn search_layer<O: Send>(
489 vector_store: &impl VectorStore,
490 graph: &impl HnswGraph,
491 measure: &impl MeasureDistance,
492 on_nearest_fn: impl OnNearestItem<O>,
493 entrypoints: BoundedNearest<(usize, O)>,
494 level_index: usize,
495 ef: usize,
496 stats: &mut HnswStats,
497 visited: &mut VecSet,
498) -> HummockResult<BoundedNearest<(usize, O)>> {
499 {
500 visited.reset();
501 let mut candidates = MinDistanceHeap::with_capacity(ef);
502 for (distance, &(idx, _)) in &entrypoints {
503 visited.set(idx);
504 candidates.push(distance, idx);
505 }
506 let mut nearest = entrypoints;
507 nearest.resize(ef);
508
509 while let Some((c_distance, c_index)) = candidates.pop() {
510 let (f_distance, _) = nearest.furthest().expect("non-empty");
511 if c_distance > f_distance {
512 break;
515 }
516 stats.nhops += 1;
517 for (neighbour_index, _) in graph.node_neighbours(c_index, level_index) {
518 if visited.is_set(neighbour_index) {
519 continue;
520 }
521 visited.set(neighbour_index);
522 let vector = vector_store.get_vector(neighbour_index).await?;
523 let info = vector.info();
524 let distance = measure.measure(vector.vec_ref());
525 stats.distances_computed += 1;
526 let mut added = false;
527 let added = &mut added;
528 nearest.insert(distance, || {
529 *added = true;
530 (
531 neighbour_index,
532 on_nearest_fn(vector.vec_ref(), distance, info),
533 )
534 });
535 if *added {
536 candidates.push(distance, neighbour_index);
537 }
538 }
539 }
540
541 Ok(nearest)
542 }
543}
544
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::iter::repeat_with;
    use std::time::{Duration, Instant};

    use bytes::Bytes;
    use faiss::{ConcurrentIndex, Index, MetricType};
    use futures::executor::block_on;
    use itertools::Itertools;
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    use crate::vector::NearestBuilder;
    use crate::vector::distance::InnerProductDistance;
    use crate::vector::hnsw::{HnswBuilder, HnswBuilderOptions, nearest};
    use crate::vector::test_utils::{gen_info, gen_vector};

    /// Fraction of `expected` items that also appear in `actual`.
    fn recall(actual: &[Bytes], expected: &[Bytes]) -> f32 {
        let expected: HashSet<_> = expected.iter().map(|b| b.as_ref()).collect();
        (actual
            .iter()
            .filter(|info| expected.contains(info.as_ref()))
            .count() as f32)
            / (expected.len() as f32)
    }

    const VERBOSE: bool = false;
    const VECTOR_LEN: usize = 128;
    const INPUT_COUNT: usize = 20000;
    const QUERY_COUNT: usize = 5000;
    const TOP_N: usize = 10;
    const EF_SEARCH_LIST: &[usize] = &[16];

    /// Builds the same data set with this crate's HNSW and with faiss, then
    /// prints build time, query time and average recall of both against an
    /// exhaustive `NearestBuilder` scan. Nothing is asserted on the numbers.
    #[tokio::test]
    async fn test_hnsw_basic() {
        let input = (0..INPUT_COUNT)
            .map(|i| (gen_vector(VECTOR_LEN), gen_info(i)))
            .collect_vec();
        let m = 40;
        let hnsw_start_time = Instant::now();
        let mut hnsw_builder = HnswBuilder::<_, _, InnerProductDistance, _>::new(
            VECTOR_LEN,
            StdRng::seed_from_u64(233),
            HnswBuilderOptions {
                m,
                ef_construction: 40,
                max_level: 10,
            },
        );
        for (vec, info) in &input {
            hnsw_builder.insert(vec.to_ref(), info).await.unwrap();
        }
        println!("hnsw build time: {:?}", hnsw_start_time.elapsed());
        if VERBOSE {
            hnsw_builder.print_graph();
        }

        let faiss_hnsw_start_time = Instant::now();
        let mut faiss_hnsw = faiss::index::hnsw::HnswFlatIndex::new(
            VECTOR_LEN as _,
            m as _,
            MetricType::InnerProduct,
        )
        .unwrap();

        faiss_hnsw
            .add(&hnsw_builder.vector_store.vector_payload)
            .unwrap();
        if VERBOSE {
            let faiss_hnsw = faiss_hnsw.hnsw();
            let (entry_point, max_level) = faiss_hnsw.entry_point().unwrap();
            println!("faiss hnsw entry_point: {} {}", entry_point, max_level);
            let levels = faiss_hnsw.levels_raw();
            println!("entry point level: {}", levels[entry_point]);
            for level in 0..=max_level {
                let neighbors = faiss_hnsw.neighbors_raw(entry_point, level);
                println!("entry point level {} neighbors {:?}", level, neighbors);
            }
        }
        println!(
            "faiss hnsw build time: {:?}",
            faiss_hnsw_start_time.elapsed()
        );

        let queries = (0..QUERY_COUNT)
            .map(|_| gen_vector(VECTOR_LEN))
            .collect_vec();
        // Ground truth: exhaustive scan over the whole input for each query.
        let expected = queries
            .iter()
            .map(|query| {
                let mut nearest_builder =
                    NearestBuilder::<'_, _, InnerProductDistance>::new(query.to_ref(), TOP_N);
                nearest_builder.add(
                    input
                        .iter()
                        .map(|(vec, info)| (vec.to_ref(), info.as_ref())),
                    |_, _, info| Bytes::copy_from_slice(info),
                );
                nearest_builder.finish()
            })
            .collect_vec();
        let faiss_start_time = Instant::now();
        // Repeat the query set in optimised builds to get steadier timings.
        let repeat_query = if cfg!(debug_assertions) { 1 } else { 60 };
        println!("start faiss query");
        let faiss_actual = repeat_with(|| queries.iter().enumerate())
            .take(repeat_query)
            .flatten()
            .map(|(i, query)| {
                let start_time = Instant::now();
                let actual = faiss_hnsw
                    .assign(&query.0, TOP_N)
                    .unwrap()
                    .labels
                    .into_iter()
                    .filter_map(|i| i.get().map(|i| gen_info(i as _)))
                    .collect_vec();
                let recall = recall(&actual, &expected[i]);
                (start_time.elapsed(), recall)
            })
            .collect_vec();
        let faiss_query_time = faiss_start_time.elapsed();
        println!("start query");
        let actuals = EF_SEARCH_LIST
            .iter()
            .map(|&ef_search| {
                let start_time = Instant::now();
                let actuals = repeat_with(|| queries.iter().enumerate())
                    .take(repeat_query)
                    .flatten()
                    .map(|(i, query)| {
                        let start_time = Instant::now();
                        let (actual, stats) = block_on(nearest::<_, InnerProductDistance>(
                            &hnsw_builder.vector_store,
                            hnsw_builder.graph.as_ref().unwrap(),
                            query.to_ref(),
                            |_, _, info| Bytes::copy_from_slice(info),
                            ef_search,
                            TOP_N,
                        ))
                        .unwrap();
                        if VERBOSE {
                            println!("stats: {:?}", stats);
                        }
                        let recall = recall(&actual, &expected[i]);
                        (start_time.elapsed(), recall)
                    })
                    .collect_vec();
                (actuals, start_time.elapsed())
            })
            .collect_vec();
        if VERBOSE {
            // Per-query latency and recall for the first few queries, faiss
            // first and then one column per `ef_search` value.
            for i in 0..20 {
                for elapsed in [&faiss_actual]
                    .into_iter()
                    .chain(actuals.iter().map(|(actual, _)| actual))
                    .map(|actual| actual[i].0)
                {
                    print!("{:?}\t", elapsed);
                }
                println!();
                for recall in [&faiss_actual]
                    .into_iter()
                    .chain(actuals.iter().map(|(actual, _)| actual))
                    .map(|actual| actual[i].1)
                {
                    print!("{}\t", recall);
                }
                println!();
            }
        }
        /// Mean of the recall component of `(latency, recall)` samples.
        fn avg_recall(actual: &[(Duration, f32)]) -> f32 {
            actual.iter().map(|(_, recall)| *recall).sum::<f32>() / (actual.len() as f32)
        }
        println!("faiss {:?} {}", faiss_query_time, avg_recall(&faiss_actual));
        for (ef_search, (actual, elapsed)) in EF_SEARCH_LIST.iter().zip(&actuals) {
            println!(
                "ef_search[{}] {:?} {}",
                ef_search,
                elapsed,
                avg_recall(actual)
            );
        }
    }
}