risingwave_hummock_sdk/
vector_index.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use risingwave_common::catalog::TableId;
18use risingwave_pb::hummock::vector_index::PbVariant;
19use risingwave_pb::hummock::vector_index_delta::vector_index_init::Config;
20use risingwave_pb::hummock::vector_index_delta::{
21    PbVectorIndexAdd, PbVectorIndexInit, vector_index_add,
22};
23use risingwave_pb::hummock::{
24    PbDistanceType, PbFlatIndex, PbFlatIndexAdd, PbFlatIndexConfig, PbVectorFileInfo,
25    PbVectorIndex, PbVectorIndexDelta, vector_index_delta,
26};
27
28use crate::{HummockObjectId, HummockVectorFileId};
29
30#[derive(Clone, Debug, PartialEq)]
31pub struct VectorFileInfo {
32    pub object_id: HummockVectorFileId,
33    pub file_size: u64,
34}
35
36impl From<PbVectorFileInfo> for VectorFileInfo {
37    fn from(pb: PbVectorFileInfo) -> Self {
38        Self {
39            object_id: pb.object_id.into(),
40            file_size: pb.file_size,
41        }
42    }
43}
44
45impl From<VectorFileInfo> for PbVectorFileInfo {
46    fn from(info: VectorFileInfo) -> Self {
47        Self {
48            object_id: info.object_id.inner(),
49            file_size: info.file_size,
50        }
51    }
52}
53
54#[derive(Debug, Clone, PartialEq)]
55pub struct FlatIndex {
56    pub config: PbFlatIndexConfig,
57    pub vector_files: Vec<VectorFileInfo>,
58}
59
60impl FlatIndex {
61    fn new(config: &PbFlatIndexConfig) -> FlatIndex {
62        FlatIndex {
63            config: *config,
64            vector_files: vec![],
65        }
66    }
67
68    fn apply_flat_index_add(&mut self, add: &FlatIndexAdd) {
69        self.vector_files
70            .extend(add.added_vector_files.iter().cloned());
71    }
72}
73
74impl From<PbFlatIndex> for FlatIndex {
75    fn from(pb: PbFlatIndex) -> Self {
76        Self {
77            config: pb.config.unwrap(),
78            vector_files: pb
79                .vector_files
80                .into_iter()
81                .map(VectorFileInfo::from)
82                .collect(),
83        }
84    }
85}
86impl From<FlatIndex> for PbFlatIndex {
87    fn from(index: FlatIndex) -> Self {
88        Self {
89            config: Some(index.config),
90            vector_files: index
91                .vector_files
92                .into_iter()
93                .map(PbVectorFileInfo::from)
94                .collect(),
95        }
96    }
97}
98
99#[derive(Debug, Clone, PartialEq)]
100pub enum VectorIndexImpl {
101    Flat(FlatIndex),
102}
103
104impl From<PbVariant> for VectorIndexImpl {
105    fn from(variant: PbVariant) -> Self {
106        match variant {
107            PbVariant::Flat(flat_index) => Self::Flat(flat_index.into()),
108        }
109    }
110}
111
112impl From<VectorIndexImpl> for PbVariant {
113    fn from(index: VectorIndexImpl) -> Self {
114        match index {
115            VectorIndexImpl::Flat(flat_index) => PbVariant::Flat(flat_index.into()),
116        }
117    }
118}
119
120#[derive(Debug, Clone, PartialEq)]
121pub struct VectorIndex {
122    pub dimension: usize,
123    pub distance_type: PbDistanceType,
124    pub inner: VectorIndexImpl,
125}
126
127impl VectorIndex {
128    pub fn get_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
129        // DO NOT REMOVE THIS LINE
130        // This is to ensure that when adding new variant to `HummockObjectId`,
131        // the compiler will warn us if we forget to handle it here.
132        match HummockObjectId::Sstable(0.into()) {
133            HummockObjectId::Sstable(_) => {}
134            HummockObjectId::VectorFile(_) => {}
135        };
136        match &self.inner {
137            VectorIndexImpl::Flat(flat) => flat
138                .vector_files
139                .iter()
140                .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size)),
141        }
142    }
143}
144
145impl From<PbVectorIndex> for VectorIndex {
146    fn from(pb: PbVectorIndex) -> Self {
147        Self {
148            dimension: pb.dimension as _,
149            distance_type: pb.distance_type.try_into().unwrap(),
150            inner: pb.variant.unwrap().into(),
151        }
152    }
153}
154
155#[derive(Clone, Debug, PartialEq)]
156pub struct FlatIndexAdd {
157    pub added_vector_files: Vec<VectorFileInfo>,
158}
159
160impl From<PbFlatIndexAdd> for FlatIndexAdd {
161    fn from(add: PbFlatIndexAdd) -> Self {
162        Self {
163            added_vector_files: add
164                .added_vector_files
165                .into_iter()
166                .map(VectorFileInfo::from)
167                .collect(),
168        }
169    }
170}
171
172impl From<FlatIndexAdd> for PbFlatIndexAdd {
173    fn from(add: FlatIndexAdd) -> Self {
174        Self {
175            added_vector_files: add
176                .added_vector_files
177                .into_iter()
178                .map(PbVectorFileInfo::from)
179                .collect(),
180        }
181    }
182}
183
184#[derive(Clone, Debug, PartialEq)]
185pub enum VectorIndexAdd {
186    Flat(FlatIndexAdd),
187}
188
189impl From<PbVectorIndexAdd> for VectorIndexAdd {
190    fn from(add: PbVectorIndexAdd) -> Self {
191        match add.add.unwrap() {
192            vector_index_add::Add::Flat(flat_add) => Self::Flat(flat_add.into()),
193        }
194    }
195}
196
197impl From<VectorIndexAdd> for PbVectorIndexAdd {
198    fn from(add: VectorIndexAdd) -> Self {
199        match add {
200            VectorIndexAdd::Flat(flat_add) => Self {
201                add: Some(vector_index_add::Add::Flat(flat_add.into())),
202            },
203        }
204    }
205}
206
207#[derive(Clone, Debug, PartialEq)]
208pub enum VectorIndexDelta {
209    Init(PbVectorIndexInit),
210    Adds(Vec<VectorIndexAdd>),
211}
212
213impl From<PbVectorIndexDelta> for VectorIndexDelta {
214    fn from(delta: PbVectorIndexDelta) -> Self {
215        match delta.delta.unwrap() {
216            vector_index_delta::Delta::Init(init) => Self::Init(init),
217            vector_index_delta::Delta::Adds(adds) => {
218                Self::Adds(adds.adds.into_iter().map(Into::into).collect())
219            }
220        }
221    }
222}
223
224impl VectorIndexDelta {
225    pub fn newly_added_objects(&self) -> impl Iterator<Item = (HummockObjectId, u64)> + '_ {
226        match self {
227            VectorIndexDelta::Init(_) => None,
228            VectorIndexDelta::Adds(adds) => Some(adds.iter().flat_map(|add| {
229                match add {
230                    VectorIndexAdd::Flat(add) => add
231                        .added_vector_files
232                        .iter()
233                        .map(|file| (HummockObjectId::VectorFile(file.object_id), file.file_size)),
234                }
235            })),
236        }
237        .into_iter()
238        .flatten()
239    }
240}
241
242impl From<VectorIndexDelta> for PbVectorIndexDelta {
243    fn from(delta: VectorIndexDelta) -> Self {
244        match delta {
245            VectorIndexDelta::Init(init) => Self {
246                delta: Some(vector_index_delta::Delta::Init(init)),
247            },
248            VectorIndexDelta::Adds(adds) => Self {
249                delta: Some(vector_index_delta::Delta::Adds(
250                    vector_index_delta::VectorIndexAdds {
251                        adds: adds.into_iter().map(Into::into).collect(),
252                    },
253                )),
254            },
255        }
256    }
257}
258
259impl From<VectorIndex> for PbVectorIndex {
260    fn from(index: VectorIndex) -> Self {
261        Self {
262            dimension: index.dimension as _,
263            distance_type: index.distance_type as _,
264            variant: Some(index.inner.into()),
265        }
266    }
267}
268
269fn init_vector_index(init: &PbVectorIndexInit) -> VectorIndex {
270    let inner = match init.config.as_ref().unwrap() {
271        Config::Flat(config) => VectorIndexImpl::Flat(FlatIndex::new(config)),
272    };
273    VectorIndex {
274        dimension: init.dimension as _,
275        distance_type: init.distance_type.try_into().unwrap(),
276        inner,
277    }
278}
279
280fn apply_vector_index_add(inner: &mut VectorIndexImpl, add: &VectorIndexAdd) {
281    match inner {
282        VectorIndexImpl::Flat(flat_index) => {
283            #[expect(irrefutable_let_patterns)]
284            let VectorIndexAdd::Flat(add) = add else {
285                panic!("expect FlatIndexAdd but got {:?}", flat_index);
286            };
287            flat_index.apply_flat_index_add(add);
288        }
289    }
290}
291
292pub fn apply_vector_index_delta(
293    vector_index: &mut HashMap<TableId, VectorIndex>,
294    vector_index_delta: &HashMap<TableId, VectorIndexDelta>,
295    removed_table_ids: &HashSet<TableId>,
296) {
297    for (table_id, vector_index_delta) in vector_index_delta {
298        match vector_index_delta {
299            VectorIndexDelta::Init(init) => {
300                vector_index
301                    .try_insert(*table_id, init_vector_index(init))
302                    .unwrap();
303            }
304            VectorIndexDelta::Adds(adds) => {
305                let inner = &mut vector_index.get_mut(table_id).unwrap().inner;
306                for add in adds {
307                    apply_vector_index_add(inner, add);
308                }
309            }
310        }
311    }
312
313    // Remove the vector index for the tables that are removed
314    vector_index.retain(|table_id, _| !removed_table_ids.contains(table_id));
315}