risingwave_hummock_sdk/
vector_index.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use risingwave_common::catalog::TableId;
18use risingwave_pb::hummock::vector_index::PbVariant;
19use risingwave_pb::hummock::vector_index_delta::vector_index_init::Config;
20use risingwave_pb::hummock::vector_index_delta::{
21    PbVectorIndexAdd, PbVectorIndexInit, vector_index_add,
22};
23use risingwave_pb::hummock::{
24    PbDistanceType, PbFlatIndex, PbFlatIndexAdd, PbFlatIndexConfig, PbVectorFileInfo,
25    PbVectorIndex, PbVectorIndexDelta, vector_index_delta,
26};
27
28use crate::{HummockObjectId, HummockVectorFileId};
29
/// In-memory counterpart of [`PbVectorFileInfo`].
///
/// The id/count/offset fields are held as `usize` here and converted with
/// checked conversions when going to or from the protobuf form.
#[derive(Clone, Debug, PartialEq)]
pub struct VectorFileInfo {
    /// Id of the object backing this vector file.
    pub object_id: HummockVectorFileId,
    /// Size of the file; reported together with the object id by
    /// `VectorIndex::get_objects`. Presumably in bytes -- confirm with the writer.
    pub file_size: u64,
    /// Start of the vector id range covered by this file. Files appended to a
    /// `VectorStoreInfo` must not start before the previous file's
    /// `start_vector_id + vector_count`.
    pub start_vector_id: usize,
    /// Number of vector ids covered by this file.
    pub vector_count: usize,
    /// NOTE(review): presumably the offset of the file's meta section inside
    /// the object -- not established by this module, confirm with the file format.
    pub meta_offset: usize,
}
38
39impl From<PbVectorFileInfo> for VectorFileInfo {
40    fn from(pb: PbVectorFileInfo) -> Self {
41        Self {
42            object_id: pb.object_id.into(),
43            file_size: pb.file_size,
44            start_vector_id: pb.start_vector_id.try_into().unwrap(),
45            vector_count: pb.vector_count.try_into().unwrap(),
46            meta_offset: pb.meta_offset.try_into().unwrap(),
47        }
48    }
49}
50
51impl From<VectorFileInfo> for PbVectorFileInfo {
52    fn from(info: VectorFileInfo) -> Self {
53        Self {
54            object_id: info.object_id.inner(),
55            file_size: info.file_size,
56            start_vector_id: info.start_vector_id.try_into().unwrap(),
57            vector_count: info.vector_count.try_into().unwrap(),
58            meta_offset: info.meta_offset.try_into().unwrap(),
59        }
60    }
61}
62
/// The set of vector files that make up a vector store, plus the next
/// vector id recorded for it.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorStoreInfo {
    /// Overwritten by every applied delta; after a delta is applied it is
    /// asserted to be at least the last file's `start_vector_id + vector_count`.
    pub next_vector_id: usize,
    /// Vector files in the order they were appended.
    pub vector_files: Vec<VectorFileInfo>,
}
68
/// An incremental change to a [`VectorStoreInfo`].
#[derive(Clone, Debug, PartialEq)]
pub struct VectorStoreInfoDelta {
    /// Value that replaces `VectorStoreInfo::next_vector_id` once the delta is applied.
    pub next_vector_id: usize,
    /// Files appended, in order, to `VectorStoreInfo::vector_files`.
    pub added_vector_files: Vec<VectorFileInfo>,
}
74
75impl VectorStoreInfo {
76    fn empty() -> Self {
77        Self {
78            next_vector_id: 0,
79            vector_files: vec![],
80        }
81    }
82
83    fn apply_vector_store_delta(&mut self, delta: &VectorStoreInfoDelta) {
84        for new_vector_file in &delta.added_vector_files {
85            if let Some(latest_vector_file) = self.vector_files.last() {
86                assert!(
87                    new_vector_file.start_vector_id
88                        >= latest_vector_file.start_vector_id + latest_vector_file.vector_count,
89                    "new vector file's start vector id {} should be greater than the last vector file's start vector id {} + vector count {}",
90                    new_vector_file.start_vector_id,
91                    latest_vector_file.start_vector_id,
92                    latest_vector_file.vector_count
93                );
94            }
95            self.vector_files.push(new_vector_file.clone());
96        }
97        self.next_vector_id = delta.next_vector_id;
98        if let Some(latest_vector_file) = self.vector_files.last() {
99            assert!(
100                latest_vector_file.start_vector_id + latest_vector_file.vector_count
101                    <= self.next_vector_id,
102                "next_vector_id {} should be greater than the last vector file's start vector id {} + vector count {}",
103                self.next_vector_id,
104                latest_vector_file.start_vector_id,
105                latest_vector_file.vector_count
106            );
107        }
108    }
109}
110
/// In-memory counterpart of [`PbFlatIndex`]: a flat index configuration plus
/// the vector store it refers to.
#[derive(Debug, Clone, PartialEq)]
pub struct FlatIndex {
    /// Configuration the index was created with.
    pub config: PbFlatIndexConfig,
    /// Vector files and next vector id of this index.
    pub vector_store_info: VectorStoreInfo,
}
116
117impl FlatIndex {
118    fn new(config: &PbFlatIndexConfig) -> FlatIndex {
119        FlatIndex {
120            config: *config,
121            vector_store_info: VectorStoreInfo::empty(),
122        }
123    }
124
125    fn apply_flat_index_add(&mut self, add: &FlatIndexAdd) {
126        self.vector_store_info
127            .apply_vector_store_delta(&add.vector_store_info_delta);
128    }
129}
130
131impl From<PbFlatIndex> for FlatIndex {
132    fn from(pb: PbFlatIndex) -> Self {
133        Self {
134            config: pb.config.unwrap(),
135            vector_store_info: VectorStoreInfo {
136                next_vector_id: pb.next_vector_id.try_into().unwrap(),
137                vector_files: pb
138                    .vector_files
139                    .into_iter()
140                    .map(VectorFileInfo::from)
141                    .collect(),
142            },
143        }
144    }
145}
146impl From<FlatIndex> for PbFlatIndex {
147    fn from(index: FlatIndex) -> Self {
148        Self {
149            config: Some(index.config),
150            next_vector_id: index.vector_store_info.next_vector_id.try_into().unwrap(),
151            vector_files: index
152                .vector_store_info
153                .vector_files
154                .into_iter()
155                .map(PbVectorFileInfo::from)
156                .collect(),
157        }
158    }
159}
160
/// The concrete kind of a vector index; in-memory counterpart of [`PbVariant`].
#[derive(Debug, Clone, PartialEq)]
pub enum VectorIndexImpl {
    /// A flat index.
    Flat(FlatIndex),
}
165
166impl From<PbVariant> for VectorIndexImpl {
167    fn from(variant: PbVariant) -> Self {
168        match variant {
169            PbVariant::Flat(flat_index) => Self::Flat(flat_index.into()),
170        }
171    }
172}
173
174impl From<VectorIndexImpl> for PbVariant {
175    fn from(index: VectorIndexImpl) -> Self {
176        match index {
177            VectorIndexImpl::Flat(flat_index) => PbVariant::Flat(flat_index.into()),
178        }
179    }
180}
181
/// In-memory counterpart of [`PbVectorIndex`].
#[derive(Debug, Clone, PartialEq)]
pub struct VectorIndex {
    /// Dimension recorded for the index (taken from the protobuf `dimension` field).
    pub dimension: usize,
    /// Distance type recorded for the index.
    pub distance_type: PbDistanceType,
    /// The concrete index implementation and its state.
    pub inner: VectorIndexImpl,
}
188
189impl VectorIndex {
190    pub fn get_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
191        // DO NOT REMOVE THIS LINE
192        // This is to ensure that when adding new variant to `HummockObjectId`,
193        // the compiler will warn us if we forget to handle it here.
194        match HummockObjectId::Sstable(0.into()) {
195            HummockObjectId::Sstable(_) => {}
196            HummockObjectId::VectorFile(_) => {}
197        };
198        match &self.inner {
199            VectorIndexImpl::Flat(flat) => flat
200                .vector_store_info
201                .vector_files
202                .iter()
203                .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size)),
204        }
205    }
206}
207
208impl From<PbVectorIndex> for VectorIndex {
209    fn from(pb: PbVectorIndex) -> Self {
210        Self {
211            dimension: pb.dimension as _,
212            distance_type: pb.distance_type.try_into().unwrap(),
213            inner: pb.variant.unwrap().into(),
214        }
215    }
216}
217
/// An addition to a [`FlatIndex`]; in-memory counterpart of [`PbFlatIndexAdd`].
#[derive(Clone, Debug, PartialEq)]
pub struct FlatIndexAdd {
    /// Change to apply to the index's vector store.
    pub vector_store_info_delta: VectorStoreInfoDelta,
}
222
223impl From<PbFlatIndexAdd> for FlatIndexAdd {
224    fn from(add: PbFlatIndexAdd) -> Self {
225        Self {
226            vector_store_info_delta: VectorStoreInfoDelta {
227                next_vector_id: add.next_vector_id.try_into().unwrap(),
228                added_vector_files: add
229                    .added_vector_files
230                    .into_iter()
231                    .map(VectorFileInfo::from)
232                    .collect(),
233            },
234        }
235    }
236}
237
238impl From<FlatIndexAdd> for PbFlatIndexAdd {
239    fn from(add: FlatIndexAdd) -> Self {
240        Self {
241            next_vector_id: add
242                .vector_store_info_delta
243                .next_vector_id
244                .try_into()
245                .unwrap(),
246            added_vector_files: add
247                .vector_store_info_delta
248                .added_vector_files
249                .into_iter()
250                .map(PbVectorFileInfo::from)
251                .collect(),
252        }
253    }
254}
255
/// An addition to a vector index, tagged by the kind of index it targets;
/// in-memory counterpart of [`PbVectorIndexAdd`].
#[derive(Clone, Debug, PartialEq)]
pub enum VectorIndexAdd {
    /// Addition targeting a flat index.
    Flat(FlatIndexAdd),
}
260
261impl From<PbVectorIndexAdd> for VectorIndexAdd {
262    fn from(add: PbVectorIndexAdd) -> Self {
263        match add.add.unwrap() {
264            vector_index_add::Add::Flat(flat_add) => Self::Flat(flat_add.into()),
265        }
266    }
267}
268
269impl From<VectorIndexAdd> for PbVectorIndexAdd {
270    fn from(add: VectorIndexAdd) -> Self {
271        match add {
272            VectorIndexAdd::Flat(flat_add) => Self {
273                add: Some(vector_index_add::Add::Flat(flat_add.into())),
274            },
275        }
276    }
277}
278
/// A change to the vector index of one table; in-memory counterpart of
/// [`PbVectorIndexDelta`].
#[derive(Clone, Debug, PartialEq)]
pub enum VectorIndexDelta {
    /// Creates a new vector index from the given init message.
    Init(PbVectorIndexInit),
    /// Additions applied, in order, to an already existing vector index.
    Adds(Vec<VectorIndexAdd>),
}
284
285impl From<PbVectorIndexDelta> for VectorIndexDelta {
286    fn from(delta: PbVectorIndexDelta) -> Self {
287        match delta.delta.unwrap() {
288            vector_index_delta::Delta::Init(init) => Self::Init(init),
289            vector_index_delta::Delta::Adds(adds) => {
290                Self::Adds(adds.adds.into_iter().map(Into::into).collect())
291            }
292        }
293    }
294}
295
296impl VectorIndexDelta {
297    pub fn newly_added_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
298        match self {
299            VectorIndexDelta::Init(_) => None,
300            VectorIndexDelta::Adds(adds) => Some(adds.iter().flat_map(|add| {
301                match add {
302                    VectorIndexAdd::Flat(add) => add
303                        .vector_store_info_delta
304                        .added_vector_files
305                        .iter()
306                        .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size)),
307                }
308            })),
309        }
310        .into_iter()
311        .flatten()
312    }
313}
314
315impl From<VectorIndexDelta> for PbVectorIndexDelta {
316    fn from(delta: VectorIndexDelta) -> Self {
317        match delta {
318            VectorIndexDelta::Init(init) => Self {
319                delta: Some(vector_index_delta::Delta::Init(init)),
320            },
321            VectorIndexDelta::Adds(adds) => Self {
322                delta: Some(vector_index_delta::Delta::Adds(
323                    vector_index_delta::VectorIndexAdds {
324                        adds: adds.into_iter().map(Into::into).collect(),
325                    },
326                )),
327            },
328        }
329    }
330}
331
332impl From<VectorIndex> for PbVectorIndex {
333    fn from(index: VectorIndex) -> Self {
334        Self {
335            dimension: index.dimension as _,
336            distance_type: index.distance_type as _,
337            variant: Some(index.inner.into()),
338        }
339    }
340}
341
342fn init_vector_index(init: &PbVectorIndexInit) -> VectorIndex {
343    let inner = match init.config.as_ref().unwrap() {
344        Config::Flat(config) => VectorIndexImpl::Flat(FlatIndex::new(config)),
345    };
346    VectorIndex {
347        dimension: init.dimension as _,
348        distance_type: init.distance_type.try_into().unwrap(),
349        inner,
350    }
351}
352
353fn apply_vector_index_add(inner: &mut VectorIndexImpl, add: &VectorIndexAdd) {
354    match inner {
355        VectorIndexImpl::Flat(flat_index) => {
356            #[expect(irrefutable_let_patterns)]
357            let VectorIndexAdd::Flat(add) = add else {
358                panic!("expect FlatIndexAdd but got {:?}", flat_index);
359            };
360            flat_index.apply_flat_index_add(add);
361        }
362    }
363}
364
365pub fn apply_vector_index_delta(
366    vector_index: &mut HashMap<TableId, VectorIndex>,
367    vector_index_delta: &HashMap<TableId, VectorIndexDelta>,
368    removed_table_ids: &HashSet<TableId>,
369) {
370    for (table_id, vector_index_delta) in vector_index_delta {
371        match vector_index_delta {
372            VectorIndexDelta::Init(init) => {
373                vector_index
374                    .try_insert(*table_id, init_vector_index(init))
375                    .unwrap();
376            }
377            VectorIndexDelta::Adds(adds) => {
378                let inner = &mut vector_index.get_mut(table_id).unwrap().inner;
379                for add in adds {
380                    apply_vector_index_add(inner, add);
381                }
382            }
383        }
384    }
385
386    // Remove the vector index for the tables that are removed
387    vector_index.retain(|table_id, _| !removed_table_ids.contains(table_id));
388}