risingwave_hummock_sdk/
vector_index.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use risingwave_common::catalog::TableId;
18use risingwave_pb::catalog::vector_index_info::Config;
19use risingwave_pb::catalog::{PbFlatIndexConfig, PbHnswFlatIndexConfig};
20use risingwave_pb::common::PbDistanceType;
21use risingwave_pb::hummock::vector_index::PbVariant;
22use risingwave_pb::hummock::vector_index_delta::{
23    PbVectorIndexAdd, PbVectorIndexInit, vector_index_add,
24};
25use risingwave_pb::hummock::{
26    PbFlatIndex, PbFlatIndexAdd, PbHnswFlatIndex, PbHnswFlatIndexAdd, PbHnswGraphFileInfo,
27    PbVectorFileInfo, PbVectorIndex, PbVectorIndexDelta, vector_index_delta,
28};
29
30use crate::{HummockHnswGraphFileId, HummockObjectId, HummockVectorFileId};
31
32#[derive(Clone, Debug, PartialEq)]
33pub struct VectorFileInfo {
34    pub object_id: HummockVectorFileId,
35    pub file_size: u64,
36    pub start_vector_id: usize,
37    pub vector_count: usize,
38    pub meta_offset: usize,
39}
40
41impl From<PbVectorFileInfo> for VectorFileInfo {
42    fn from(pb: PbVectorFileInfo) -> Self {
43        Self {
44            object_id: pb.object_id.into(),
45            file_size: pb.file_size,
46            start_vector_id: pb.start_vector_id.try_into().unwrap(),
47            vector_count: pb.vector_count.try_into().unwrap(),
48            meta_offset: pb.meta_offset.try_into().unwrap(),
49        }
50    }
51}
52
53impl From<VectorFileInfo> for PbVectorFileInfo {
54    fn from(info: VectorFileInfo) -> Self {
55        Self {
56            object_id: info.object_id.inner(),
57            file_size: info.file_size,
58            start_vector_id: info.start_vector_id.try_into().unwrap(),
59            vector_count: info.vector_count.try_into().unwrap(),
60            meta_offset: info.meta_offset.try_into().unwrap(),
61        }
62    }
63}
64
65#[derive(Debug, Clone, PartialEq)]
66pub struct VectorStoreInfo {
67    pub next_vector_id: usize,
68    pub vector_files: Vec<VectorFileInfo>,
69}
70
71#[derive(Clone, Debug, PartialEq)]
72pub struct VectorStoreInfoDelta {
73    pub next_vector_id: usize,
74    pub added_vector_files: Vec<VectorFileInfo>,
75}
76
77impl VectorStoreInfo {
78    fn empty() -> Self {
79        Self {
80            next_vector_id: 0,
81            vector_files: vec![],
82        }
83    }
84
85    fn apply_vector_store_delta(&mut self, delta: &VectorStoreInfoDelta) {
86        for new_vector_file in &delta.added_vector_files {
87            if let Some(latest_vector_file) = self.vector_files.last() {
88                assert!(
89                    new_vector_file.start_vector_id
90                        >= latest_vector_file.start_vector_id + latest_vector_file.vector_count,
91                    "new vector file's start vector id {} should be greater than the last vector file's start vector id {} + vector count {}",
92                    new_vector_file.start_vector_id,
93                    latest_vector_file.start_vector_id,
94                    latest_vector_file.vector_count
95                );
96            }
97            self.vector_files.push(new_vector_file.clone());
98        }
99        self.next_vector_id = delta.next_vector_id;
100        if let Some(latest_vector_file) = self.vector_files.last() {
101            assert!(
102                latest_vector_file.start_vector_id + latest_vector_file.vector_count
103                    <= self.next_vector_id,
104                "next_vector_id {} should be greater than the last vector file's start vector id {} + vector count {}",
105                self.next_vector_id,
106                latest_vector_file.start_vector_id,
107                latest_vector_file.vector_count
108            );
109        }
110    }
111}
112
113#[derive(Debug, Clone, PartialEq)]
114pub struct FlatIndex {
115    pub config: PbFlatIndexConfig,
116    pub vector_store_info: VectorStoreInfo,
117}
118
119impl FlatIndex {
120    fn new(config: &PbFlatIndexConfig) -> FlatIndex {
121        FlatIndex {
122            config: *config,
123            vector_store_info: VectorStoreInfo::empty(),
124        }
125    }
126
127    fn apply_flat_index_add(&mut self, add: &FlatIndexAdd) {
128        self.vector_store_info
129            .apply_vector_store_delta(&add.vector_store_info_delta);
130    }
131}
132
133impl From<PbFlatIndex> for FlatIndex {
134    fn from(pb: PbFlatIndex) -> Self {
135        Self {
136            config: pb.config.unwrap(),
137            vector_store_info: VectorStoreInfo {
138                next_vector_id: pb.next_vector_id.try_into().unwrap(),
139                vector_files: pb
140                    .vector_files
141                    .into_iter()
142                    .map(VectorFileInfo::from)
143                    .collect(),
144            },
145        }
146    }
147}
148impl From<FlatIndex> for PbFlatIndex {
149    fn from(index: FlatIndex) -> Self {
150        Self {
151            config: Some(index.config),
152            next_vector_id: index.vector_store_info.next_vector_id.try_into().unwrap(),
153            vector_files: index
154                .vector_store_info
155                .vector_files
156                .into_iter()
157                .map(PbVectorFileInfo::from)
158                .collect(),
159        }
160    }
161}
162
163#[derive(Debug, Clone, PartialEq)]
164pub struct HnswGraphFileInfo {
165    pub object_id: HummockHnswGraphFileId,
166    pub file_size: u64,
167}
168
169impl From<PbHnswGraphFileInfo> for HnswGraphFileInfo {
170    fn from(pb: PbHnswGraphFileInfo) -> Self {
171        Self {
172            object_id: pb.object_id.into(),
173            file_size: pb.file_size,
174        }
175    }
176}
177
178impl From<HnswGraphFileInfo> for PbHnswGraphFileInfo {
179    fn from(info: HnswGraphFileInfo) -> Self {
180        Self {
181            object_id: info.object_id.inner(),
182            file_size: info.file_size,
183        }
184    }
185}
186
187#[derive(Debug, Clone, PartialEq)]
188pub struct HnswFlatIndex {
189    pub config: PbHnswFlatIndexConfig,
190    pub vector_store_info: VectorStoreInfo,
191    pub graph_file: Option<HnswGraphFileInfo>,
192}
193
194impl HnswFlatIndex {
195    fn new(config: &PbHnswFlatIndexConfig) -> HnswFlatIndex {
196        HnswFlatIndex {
197            config: *config,
198            vector_store_info: VectorStoreInfo::empty(),
199            graph_file: None,
200        }
201    }
202
203    fn apply_hnsw_flat_index_add(&mut self, add: &HnswFlatIndexAdd) {
204        self.vector_store_info
205            .apply_vector_store_delta(&add.vector_store_info_delta);
206        self.graph_file = Some(add.graph_file.clone());
207        if self.graph_file.is_some() {
208            assert!(
209                !self.vector_store_info.vector_files.is_empty(),
210                "HNSW Flat Index must have at least one vector file when a graph file is present"
211            );
212        }
213    }
214}
215
216impl From<PbHnswFlatIndex> for HnswFlatIndex {
217    fn from(pb: PbHnswFlatIndex) -> Self {
218        Self {
219            config: pb.config.unwrap(),
220            vector_store_info: VectorStoreInfo {
221                next_vector_id: pb.next_vector_id.try_into().unwrap(),
222                vector_files: pb
223                    .vector_files
224                    .into_iter()
225                    .map(VectorFileInfo::from)
226                    .collect(),
227            },
228            graph_file: pb.graph_file.map(Into::into),
229        }
230    }
231}
232
233impl From<HnswFlatIndex> for PbHnswFlatIndex {
234    fn from(index: HnswFlatIndex) -> Self {
235        Self {
236            config: Some(index.config),
237            vector_files: index
238                .vector_store_info
239                .vector_files
240                .into_iter()
241                .map(PbVectorFileInfo::from)
242                .collect(),
243            next_vector_id: index.vector_store_info.next_vector_id.try_into().unwrap(),
244            graph_file: index.graph_file.map(Into::into),
245        }
246    }
247}
248
249#[derive(Debug, Clone, PartialEq)]
250pub enum VectorIndexImpl {
251    Flat(FlatIndex),
252    HnswFlat(HnswFlatIndex),
253}
254
255impl From<PbVariant> for VectorIndexImpl {
256    fn from(variant: PbVariant) -> Self {
257        match variant {
258            PbVariant::Flat(flat_index) => Self::Flat(flat_index.into()),
259            PbVariant::HnswFlat(hnsw_flat_index) => Self::HnswFlat(hnsw_flat_index.into()),
260        }
261    }
262}
263
264impl From<VectorIndexImpl> for PbVariant {
265    fn from(index: VectorIndexImpl) -> Self {
266        match index {
267            VectorIndexImpl::Flat(flat_index) => PbVariant::Flat(flat_index.into()),
268            VectorIndexImpl::HnswFlat(hnsw_flat_index) => {
269                PbVariant::HnswFlat(hnsw_flat_index.into())
270            }
271        }
272    }
273}
274
275#[derive(Debug, Clone, PartialEq)]
276pub struct VectorIndex {
277    pub dimension: usize,
278    pub distance_type: PbDistanceType,
279    pub inner: VectorIndexImpl,
280}
281
282impl VectorIndex {
283    pub fn get_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
284        // DO NOT REMOVE THIS LINE
285        // This is to ensure that when adding new variant to `HummockObjectId`,
286        // the compiler will warn us if we forget to handle it here.
287        match HummockObjectId::Sstable(0.into()) {
288            HummockObjectId::Sstable(_) => {}
289            HummockObjectId::VectorFile(_) => {}
290            HummockObjectId::HnswGraphFile(_) => {}
291        };
292        let vector_files = match &self.inner {
293            VectorIndexImpl::Flat(flat) => &flat.vector_store_info.vector_files,
294            VectorIndexImpl::HnswFlat(hnsw_flat) => &hnsw_flat.vector_store_info.vector_files,
295        };
296        let graph_file_object_id = match &self.inner {
297            VectorIndexImpl::Flat(_) => None,
298            VectorIndexImpl::HnswFlat(hnsw_flat) => hnsw_flat.graph_file.as_ref().map(|file| {
299                (
300                    HummockObjectId::HnswGraphFile(file.object_id),
301                    file.file_size,
302                )
303            }),
304        };
305        vector_files
306            .iter()
307            .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size))
308            .chain(graph_file_object_id)
309    }
310}
311
312impl From<PbVectorIndex> for VectorIndex {
313    fn from(pb: PbVectorIndex) -> Self {
314        Self {
315            dimension: pb.dimension as _,
316            distance_type: pb.distance_type.try_into().unwrap(),
317            inner: pb.variant.unwrap().into(),
318        }
319    }
320}
321
322#[derive(Clone, Debug, PartialEq)]
323pub struct FlatIndexAdd {
324    pub vector_store_info_delta: VectorStoreInfoDelta,
325}
326
327impl From<PbFlatIndexAdd> for FlatIndexAdd {
328    fn from(add: PbFlatIndexAdd) -> Self {
329        Self {
330            vector_store_info_delta: VectorStoreInfoDelta {
331                next_vector_id: add.next_vector_id.try_into().unwrap(),
332                added_vector_files: add
333                    .added_vector_files
334                    .into_iter()
335                    .map(VectorFileInfo::from)
336                    .collect(),
337            },
338        }
339    }
340}
341
342impl From<FlatIndexAdd> for PbFlatIndexAdd {
343    fn from(add: FlatIndexAdd) -> Self {
344        Self {
345            next_vector_id: add
346                .vector_store_info_delta
347                .next_vector_id
348                .try_into()
349                .unwrap(),
350            added_vector_files: add
351                .vector_store_info_delta
352                .added_vector_files
353                .into_iter()
354                .map(PbVectorFileInfo::from)
355                .collect(),
356        }
357    }
358}
359
360#[derive(Clone, Debug, PartialEq)]
361pub struct HnswFlatIndexAdd {
362    pub vector_store_info_delta: VectorStoreInfoDelta,
363    pub graph_file: HnswGraphFileInfo,
364}
365
366impl From<PbHnswFlatIndexAdd> for HnswFlatIndexAdd {
367    fn from(add: PbHnswFlatIndexAdd) -> Self {
368        Self {
369            vector_store_info_delta: VectorStoreInfoDelta {
370                next_vector_id: add.next_vector_id.try_into().unwrap(),
371                added_vector_files: add
372                    .added_vector_files
373                    .into_iter()
374                    .map(VectorFileInfo::from)
375                    .collect(),
376            },
377            graph_file: add.graph_file.unwrap().into(),
378        }
379    }
380}
381
382impl From<HnswFlatIndexAdd> for PbHnswFlatIndexAdd {
383    fn from(add: HnswFlatIndexAdd) -> Self {
384        Self {
385            added_vector_files: add
386                .vector_store_info_delta
387                .added_vector_files
388                .into_iter()
389                .map(PbVectorFileInfo::from)
390                .collect(),
391            next_vector_id: add
392                .vector_store_info_delta
393                .next_vector_id
394                .try_into()
395                .unwrap(),
396            graph_file: Some(add.graph_file.into()),
397        }
398    }
399}
400
401#[derive(Clone, Debug, PartialEq)]
402pub enum VectorIndexAdd {
403    Flat(FlatIndexAdd),
404    HnswFlat(HnswFlatIndexAdd),
405}
406
407impl From<PbVectorIndexAdd> for VectorIndexAdd {
408    fn from(add: PbVectorIndexAdd) -> Self {
409        match add.add.unwrap() {
410            vector_index_add::Add::Flat(flat_add) => Self::Flat(flat_add.into()),
411            vector_index_add::Add::HnswFlat(hnsw_flat_add) => Self::HnswFlat(hnsw_flat_add.into()),
412        }
413    }
414}
415
416impl From<VectorIndexAdd> for PbVectorIndexAdd {
417    fn from(add: VectorIndexAdd) -> Self {
418        match add {
419            VectorIndexAdd::Flat(flat_add) => Self {
420                add: Some(vector_index_add::Add::Flat(flat_add.into())),
421            },
422            VectorIndexAdd::HnswFlat(hnsw_flat_add) => Self {
423                add: Some(vector_index_add::Add::HnswFlat(hnsw_flat_add.into())),
424            },
425        }
426    }
427}
428
429#[derive(Clone, Debug, PartialEq)]
430pub enum VectorIndexDelta {
431    Init(PbVectorIndexInit),
432    Adds(Vec<VectorIndexAdd>),
433}
434
435impl From<PbVectorIndexDelta> for VectorIndexDelta {
436    fn from(delta: PbVectorIndexDelta) -> Self {
437        match delta.delta.unwrap() {
438            vector_index_delta::Delta::Init(init) => Self::Init(init),
439            vector_index_delta::Delta::Adds(adds) => {
440                Self::Adds(adds.adds.into_iter().map(Into::into).collect())
441            }
442        }
443    }
444}
445
446impl VectorIndexDelta {
447    pub fn newly_added_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
448        // DO NOT REMOVE THIS LINE
449        // This is to ensure that when adding new variant to `HummockObjectId`,
450        // the compiler will warn us if we forget to handle it here.
451        match HummockObjectId::Sstable(0.into()) {
452            HummockObjectId::Sstable(_) => {}
453            HummockObjectId::VectorFile(_) => {}
454            HummockObjectId::HnswGraphFile(_) => {}
455        };
456        match self {
457            VectorIndexDelta::Init(_) => None,
458            VectorIndexDelta::Adds(adds) => Some(adds.iter().flat_map(|add| {
459                let vector_store_delta = match add {
460                    VectorIndexAdd::Flat(add) => &add.vector_store_info_delta,
461                    VectorIndexAdd::HnswFlat(add) => &add.vector_store_info_delta,
462                };
463                let added_graph_file = match add {
464                    VectorIndexAdd::Flat(_) => None,
465                    VectorIndexAdd::HnswFlat(add) => Some((
466                        HummockObjectId::HnswGraphFile(add.graph_file.object_id),
467                        add.graph_file.file_size,
468                    )),
469                };
470                vector_store_delta
471                    .added_vector_files
472                    .iter()
473                    .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size))
474                    .chain(added_graph_file)
475            })),
476        }
477        .into_iter()
478        .flatten()
479    }
480}
481
482impl From<VectorIndexDelta> for PbVectorIndexDelta {
483    fn from(delta: VectorIndexDelta) -> Self {
484        match delta {
485            VectorIndexDelta::Init(init) => Self {
486                delta: Some(vector_index_delta::Delta::Init(init)),
487            },
488            VectorIndexDelta::Adds(adds) => Self {
489                delta: Some(vector_index_delta::Delta::Adds(
490                    vector_index_delta::VectorIndexAdds {
491                        adds: adds.into_iter().map(Into::into).collect(),
492                    },
493                )),
494            },
495        }
496    }
497}
498
499impl From<VectorIndex> for PbVectorIndex {
500    fn from(index: VectorIndex) -> Self {
501        Self {
502            dimension: index.dimension as _,
503            distance_type: index.distance_type as _,
504            variant: Some(index.inner.into()),
505        }
506    }
507}
508
509fn init_vector_index(init: &PbVectorIndexInit) -> VectorIndex {
510    let init_info = init.info.as_ref().unwrap();
511    let inner = match init_info.config.as_ref().unwrap() {
512        Config::Flat(config) => VectorIndexImpl::Flat(FlatIndex::new(config)),
513        Config::HnswFlat(config) => VectorIndexImpl::HnswFlat(HnswFlatIndex::new(config)),
514    };
515    VectorIndex {
516        dimension: init_info.dimension as _,
517        distance_type: init_info.distance_type.try_into().unwrap(),
518        inner,
519    }
520}
521
522fn apply_vector_index_add(inner: &mut VectorIndexImpl, add: &VectorIndexAdd) {
523    match inner {
524        VectorIndexImpl::Flat(flat_index) => {
525            let VectorIndexAdd::Flat(add) = add else {
526                panic!("expect FlatIndexAdd but got {:?}", flat_index);
527            };
528            flat_index.apply_flat_index_add(add);
529        }
530        VectorIndexImpl::HnswFlat(hnsw_flat_index) => {
531            let VectorIndexAdd::HnswFlat(add) = add else {
532                panic!("expect HnswFlatIndexAdd but got {:?}", hnsw_flat_index);
533            };
534            hnsw_flat_index.apply_hnsw_flat_index_add(add);
535        }
536    }
537}
538
539pub fn apply_vector_index_delta(
540    vector_index: &mut HashMap<TableId, VectorIndex>,
541    vector_index_delta: &HashMap<TableId, VectorIndexDelta>,
542    removed_table_ids: &HashSet<TableId>,
543) {
544    for (table_id, vector_index_delta) in vector_index_delta {
545        match vector_index_delta {
546            VectorIndexDelta::Init(init) => {
547                vector_index
548                    .try_insert(*table_id, init_vector_index(init))
549                    .unwrap();
550            }
551            VectorIndexDelta::Adds(adds) => {
552                let inner = &mut vector_index.get_mut(table_id).unwrap().inner;
553                for add in adds {
554                    apply_vector_index_add(inner, add);
555                }
556            }
557        }
558    }
559
560    // Remove the vector index for the tables that are removed
561    vector_index.retain(|table_id, _| !removed_table_ids.contains(table_id));
562}