risingwave_hummock_sdk/
vector_index.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use risingwave_common::catalog::TableId;
18use risingwave_pb::catalog::PbFlatIndexConfig;
19use risingwave_pb::catalog::vector_index_info::Config;
20use risingwave_pb::common::PbDistanceType;
21use risingwave_pb::hummock::vector_index::PbVariant;
22use risingwave_pb::hummock::vector_index_delta::{
23    PbVectorIndexAdd, PbVectorIndexInit, vector_index_add,
24};
25use risingwave_pb::hummock::{
26    PbFlatIndex, PbFlatIndexAdd, PbVectorFileInfo, PbVectorIndex, PbVectorIndexDelta,
27    vector_index_delta,
28};
29
30use crate::{HummockObjectId, HummockVectorFileId};
31
32#[derive(Clone, Debug, PartialEq)]
33pub struct VectorFileInfo {
34    pub object_id: HummockVectorFileId,
35    pub file_size: u64,
36    pub start_vector_id: usize,
37    pub vector_count: usize,
38    pub meta_offset: usize,
39}
40
41impl From<PbVectorFileInfo> for VectorFileInfo {
42    fn from(pb: PbVectorFileInfo) -> Self {
43        Self {
44            object_id: pb.object_id.into(),
45            file_size: pb.file_size,
46            start_vector_id: pb.start_vector_id.try_into().unwrap(),
47            vector_count: pb.vector_count.try_into().unwrap(),
48            meta_offset: pb.meta_offset.try_into().unwrap(),
49        }
50    }
51}
52
53impl From<VectorFileInfo> for PbVectorFileInfo {
54    fn from(info: VectorFileInfo) -> Self {
55        Self {
56            object_id: info.object_id.inner(),
57            file_size: info.file_size,
58            start_vector_id: info.start_vector_id.try_into().unwrap(),
59            vector_count: info.vector_count.try_into().unwrap(),
60            meta_offset: info.meta_offset.try_into().unwrap(),
61        }
62    }
63}
64
65#[derive(Debug, Clone, PartialEq)]
66pub struct VectorStoreInfo {
67    pub next_vector_id: usize,
68    pub vector_files: Vec<VectorFileInfo>,
69}
70
71#[derive(Clone, Debug, PartialEq)]
72pub struct VectorStoreInfoDelta {
73    pub next_vector_id: usize,
74    pub added_vector_files: Vec<VectorFileInfo>,
75}
76
77impl VectorStoreInfo {
78    fn empty() -> Self {
79        Self {
80            next_vector_id: 0,
81            vector_files: vec![],
82        }
83    }
84
85    fn apply_vector_store_delta(&mut self, delta: &VectorStoreInfoDelta) {
86        for new_vector_file in &delta.added_vector_files {
87            if let Some(latest_vector_file) = self.vector_files.last() {
88                assert!(
89                    new_vector_file.start_vector_id
90                        >= latest_vector_file.start_vector_id + latest_vector_file.vector_count,
91                    "new vector file's start vector id {} should be greater than the last vector file's start vector id {} + vector count {}",
92                    new_vector_file.start_vector_id,
93                    latest_vector_file.start_vector_id,
94                    latest_vector_file.vector_count
95                );
96            }
97            self.vector_files.push(new_vector_file.clone());
98        }
99        self.next_vector_id = delta.next_vector_id;
100        if let Some(latest_vector_file) = self.vector_files.last() {
101            assert!(
102                latest_vector_file.start_vector_id + latest_vector_file.vector_count
103                    <= self.next_vector_id,
104                "next_vector_id {} should be greater than the last vector file's start vector id {} + vector count {}",
105                self.next_vector_id,
106                latest_vector_file.start_vector_id,
107                latest_vector_file.vector_count
108            );
109        }
110    }
111}
112
113#[derive(Debug, Clone, PartialEq)]
114pub struct FlatIndex {
115    pub config: PbFlatIndexConfig,
116    pub vector_store_info: VectorStoreInfo,
117}
118
119impl FlatIndex {
120    fn new(config: &PbFlatIndexConfig) -> FlatIndex {
121        FlatIndex {
122            config: *config,
123            vector_store_info: VectorStoreInfo::empty(),
124        }
125    }
126
127    fn apply_flat_index_add(&mut self, add: &FlatIndexAdd) {
128        self.vector_store_info
129            .apply_vector_store_delta(&add.vector_store_info_delta);
130    }
131}
132
133impl From<PbFlatIndex> for FlatIndex {
134    fn from(pb: PbFlatIndex) -> Self {
135        Self {
136            config: pb.config.unwrap(),
137            vector_store_info: VectorStoreInfo {
138                next_vector_id: pb.next_vector_id.try_into().unwrap(),
139                vector_files: pb
140                    .vector_files
141                    .into_iter()
142                    .map(VectorFileInfo::from)
143                    .collect(),
144            },
145        }
146    }
147}
148impl From<FlatIndex> for PbFlatIndex {
149    fn from(index: FlatIndex) -> Self {
150        Self {
151            config: Some(index.config),
152            next_vector_id: index.vector_store_info.next_vector_id.try_into().unwrap(),
153            vector_files: index
154                .vector_store_info
155                .vector_files
156                .into_iter()
157                .map(PbVectorFileInfo::from)
158                .collect(),
159        }
160    }
161}
162
163#[derive(Debug, Clone, PartialEq)]
164pub enum VectorIndexImpl {
165    Flat(FlatIndex),
166}
167
168impl From<PbVariant> for VectorIndexImpl {
169    fn from(variant: PbVariant) -> Self {
170        match variant {
171            PbVariant::Flat(flat_index) => Self::Flat(flat_index.into()),
172        }
173    }
174}
175
176impl From<VectorIndexImpl> for PbVariant {
177    fn from(index: VectorIndexImpl) -> Self {
178        match index {
179            VectorIndexImpl::Flat(flat_index) => PbVariant::Flat(flat_index.into()),
180        }
181    }
182}
183
184#[derive(Debug, Clone, PartialEq)]
185pub struct VectorIndex {
186    pub dimension: usize,
187    pub distance_type: PbDistanceType,
188    pub inner: VectorIndexImpl,
189}
190
191impl VectorIndex {
192    pub fn get_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
193        // DO NOT REMOVE THIS LINE
194        // This is to ensure that when adding new variant to `HummockObjectId`,
195        // the compiler will warn us if we forget to handle it here.
196        match HummockObjectId::Sstable(0.into()) {
197            HummockObjectId::Sstable(_) => {}
198            HummockObjectId::VectorFile(_) => {}
199        };
200        match &self.inner {
201            VectorIndexImpl::Flat(flat) => flat
202                .vector_store_info
203                .vector_files
204                .iter()
205                .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size)),
206        }
207    }
208}
209
210impl From<PbVectorIndex> for VectorIndex {
211    fn from(pb: PbVectorIndex) -> Self {
212        Self {
213            dimension: pb.dimension as _,
214            distance_type: pb.distance_type.try_into().unwrap(),
215            inner: pb.variant.unwrap().into(),
216        }
217    }
218}
219
220#[derive(Clone, Debug, PartialEq)]
221pub struct FlatIndexAdd {
222    pub vector_store_info_delta: VectorStoreInfoDelta,
223}
224
225impl From<PbFlatIndexAdd> for FlatIndexAdd {
226    fn from(add: PbFlatIndexAdd) -> Self {
227        Self {
228            vector_store_info_delta: VectorStoreInfoDelta {
229                next_vector_id: add.next_vector_id.try_into().unwrap(),
230                added_vector_files: add
231                    .added_vector_files
232                    .into_iter()
233                    .map(VectorFileInfo::from)
234                    .collect(),
235            },
236        }
237    }
238}
239
240impl From<FlatIndexAdd> for PbFlatIndexAdd {
241    fn from(add: FlatIndexAdd) -> Self {
242        Self {
243            next_vector_id: add
244                .vector_store_info_delta
245                .next_vector_id
246                .try_into()
247                .unwrap(),
248            added_vector_files: add
249                .vector_store_info_delta
250                .added_vector_files
251                .into_iter()
252                .map(PbVectorFileInfo::from)
253                .collect(),
254        }
255    }
256}
257
258#[derive(Clone, Debug, PartialEq)]
259pub enum VectorIndexAdd {
260    Flat(FlatIndexAdd),
261}
262
263impl From<PbVectorIndexAdd> for VectorIndexAdd {
264    fn from(add: PbVectorIndexAdd) -> Self {
265        match add.add.unwrap() {
266            vector_index_add::Add::Flat(flat_add) => Self::Flat(flat_add.into()),
267        }
268    }
269}
270
271impl From<VectorIndexAdd> for PbVectorIndexAdd {
272    fn from(add: VectorIndexAdd) -> Self {
273        match add {
274            VectorIndexAdd::Flat(flat_add) => Self {
275                add: Some(vector_index_add::Add::Flat(flat_add.into())),
276            },
277        }
278    }
279}
280
281#[derive(Clone, Debug, PartialEq)]
282pub enum VectorIndexDelta {
283    Init(PbVectorIndexInit),
284    Adds(Vec<VectorIndexAdd>),
285}
286
287impl From<PbVectorIndexDelta> for VectorIndexDelta {
288    fn from(delta: PbVectorIndexDelta) -> Self {
289        match delta.delta.unwrap() {
290            vector_index_delta::Delta::Init(init) => Self::Init(init),
291            vector_index_delta::Delta::Adds(adds) => {
292                Self::Adds(adds.adds.into_iter().map(Into::into).collect())
293            }
294        }
295    }
296}
297
298impl VectorIndexDelta {
299    pub fn newly_added_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
300        match self {
301            VectorIndexDelta::Init(_) => None,
302            VectorIndexDelta::Adds(adds) => Some(adds.iter().flat_map(|add| {
303                match add {
304                    VectorIndexAdd::Flat(add) => add
305                        .vector_store_info_delta
306                        .added_vector_files
307                        .iter()
308                        .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size)),
309                }
310            })),
311        }
312        .into_iter()
313        .flatten()
314    }
315}
316
317impl From<VectorIndexDelta> for PbVectorIndexDelta {
318    fn from(delta: VectorIndexDelta) -> Self {
319        match delta {
320            VectorIndexDelta::Init(init) => Self {
321                delta: Some(vector_index_delta::Delta::Init(init)),
322            },
323            VectorIndexDelta::Adds(adds) => Self {
324                delta: Some(vector_index_delta::Delta::Adds(
325                    vector_index_delta::VectorIndexAdds {
326                        adds: adds.into_iter().map(Into::into).collect(),
327                    },
328                )),
329            },
330        }
331    }
332}
333
334impl From<VectorIndex> for PbVectorIndex {
335    fn from(index: VectorIndex) -> Self {
336        Self {
337            dimension: index.dimension as _,
338            distance_type: index.distance_type as _,
339            variant: Some(index.inner.into()),
340        }
341    }
342}
343
344fn init_vector_index(init: &PbVectorIndexInit) -> VectorIndex {
345    let init_info = init.info.as_ref().unwrap();
346    let inner = match init_info.config.as_ref().unwrap() {
347        Config::Flat(config) => VectorIndexImpl::Flat(FlatIndex::new(config)),
348    };
349    VectorIndex {
350        dimension: init_info.dimension as _,
351        distance_type: init_info.distance_type.try_into().unwrap(),
352        inner,
353    }
354}
355
356fn apply_vector_index_add(inner: &mut VectorIndexImpl, add: &VectorIndexAdd) {
357    match inner {
358        VectorIndexImpl::Flat(flat_index) => {
359            #[expect(irrefutable_let_patterns)]
360            let VectorIndexAdd::Flat(add) = add else {
361                panic!("expect FlatIndexAdd but got {:?}", flat_index);
362            };
363            flat_index.apply_flat_index_add(add);
364        }
365    }
366}
367
368pub fn apply_vector_index_delta(
369    vector_index: &mut HashMap<TableId, VectorIndex>,
370    vector_index_delta: &HashMap<TableId, VectorIndexDelta>,
371    removed_table_ids: &HashSet<TableId>,
372) {
373    for (table_id, vector_index_delta) in vector_index_delta {
374        match vector_index_delta {
375            VectorIndexDelta::Init(init) => {
376                vector_index
377                    .try_insert(*table_id, init_vector_index(init))
378                    .unwrap();
379            }
380            VectorIndexDelta::Adds(adds) => {
381                let inner = &mut vector_index.get_mut(table_id).unwrap().inner;
382                for add in adds {
383                    apply_vector_index_add(inner, add);
384                }
385            }
386        }
387    }
388
389    // Remove the vector index for the tables that are removed
390    vector_index.retain(|table_id, _| !removed_table_ids.contains(table_id));
391}