risingwave_storage/table/batch_table/vector_index_reader.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::array::{
19    Array, ArrayBuilder, ArrayImpl, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder,
20    StructValue,
21};
22use risingwave_common::catalog::{TableId, TableOption};
23use risingwave_common::row::RowDeserializer;
24use risingwave_common::types::{DataType, ScalarImpl, StructType};
25use risingwave_common::util::value_encoding::BasicDeserializer;
26use risingwave_common::vector::distance::DistanceMeasurement;
27use risingwave_hummock_sdk::HummockReadEpoch;
28use risingwave_pb::common::PbDistanceType;
29use risingwave_pb::plan_common::PbVectorIndexReaderDesc;
30
31use crate::StateStore;
32use crate::error::StorageResult;
33use crate::store::{NewReadSnapshotOptions, StateStoreReadVector, VectorNearestOptions};
34
35pub struct VectorIndexReader<S> {
36    vector_info_struct_type: StructType,
37    state_store: S,
38    table_id: TableId,
39
40    info_output_indices: Arc<Vec<usize>>,
41    include_distance: bool,
42
43    top_n: usize,
44    measure: DistanceMeasurement,
45    sqrt_distance: bool,
46    deserializer: Arc<BasicDeserializer>,
47    hnsw_ef_search: usize,
48}
49
50impl<S: StateStore> VectorIndexReader<S> {
51    pub fn new(reader_desc: &PbVectorIndexReaderDesc, state_store: S) -> Self {
52        let deserializer = Arc::new(RowDeserializer::new(
53            reader_desc
54                .info_column_desc
55                .iter()
56                .map(|col| DataType::from(col.column_type.clone().unwrap()))
57                .collect_vec(),
58        ));
59
60        let vector_info_struct_type = StructType::new(
61            reader_desc
62                .info_output_indices
63                .iter()
64                .map(|idx| {
65                    let idx = *idx as usize;
66                    (
67                        reader_desc.info_column_desc[idx].name.clone(),
68                        DataType::from(
69                            reader_desc.info_column_desc[idx]
70                                .column_type
71                                .clone()
72                                .unwrap(),
73                        ),
74                    )
75                })
76                .chain(
77                    reader_desc
78                        .include_distance
79                        .then(|| [("__distance".to_owned(), DataType::Float64)].into_iter())
80                        .into_iter()
81                        .flatten(),
82                ),
83        );
84
85        let measure = PbDistanceType::try_from(reader_desc.distance_type)
86            .unwrap()
87            .into();
88
89        let sqrt_distance = match measure {
90            DistanceMeasurement::L2Sqr => true,
91            DistanceMeasurement::L1
92            | DistanceMeasurement::Cosine
93            | DistanceMeasurement::InnerProduct => false,
94        };
95
96        Self {
97            vector_info_struct_type,
98            state_store,
99            table_id: reader_desc.table_id,
100
101            info_output_indices: reader_desc
102                .info_output_indices
103                .iter()
104                .map(|idx| *idx as _)
105                .collect_vec()
106                .into(),
107            include_distance: reader_desc.include_distance,
108            top_n: reader_desc.top_n as usize,
109            measure,
110            sqrt_distance,
111            deserializer,
112            hnsw_ef_search: reader_desc.hnsw_ef_search as usize,
113        }
114    }
115
116    pub fn info_struct_type(&self) -> &StructType {
117        &self.vector_info_struct_type
118    }
119
120    pub async fn new_snapshot(
121        &self,
122        epoch: HummockReadEpoch,
123    ) -> StorageResult<VectorIndexSnapshot<'_, S>> {
124        Ok(VectorIndexSnapshot {
125            reader: self,
126            snapshot: self
127                .state_store
128                .new_read_snapshot(
129                    epoch,
130                    NewReadSnapshotOptions {
131                        table_id: self.table_id,
132                        table_option: TableOption::default(),
133                    },
134                )
135                .await?,
136        })
137    }
138}
139
140pub struct VectorIndexSnapshot<'a, S: StateStore> {
141    reader: &'a VectorIndexReader<S>,
142    snapshot: S::ReadSnapshot,
143}
144
145impl<S: StateStore> VectorIndexSnapshot<'_, S> {
146    pub async fn query_expand_chunk(
147        &self,
148        chunk: DataChunk,
149        vector_column_idx: usize,
150    ) -> StorageResult<DataChunk> {
151        let sqrt_distance = self.reader.sqrt_distance;
152        let struct_len = self.reader.vector_info_struct_type.len();
153        let include_distance = self.reader.include_distance;
154
155        let mut vector_info_columns_builder = ListArrayBuilder::with_type(
156            chunk.cardinality(),
157            DataType::list(DataType::Struct(self.reader.info_struct_type().clone())),
158        );
159        let (mut columns, vis) = chunk.into_parts();
160        let vector_column = columns[vector_column_idx].as_vector();
161        for (idx, vis) in vis.iter().enumerate() {
162            if vis && let Some(vector) = vector_column.value_at(idx) {
163                let deserializer = self.reader.deserializer.clone();
164                let info_output_indices = self.reader.info_output_indices.clone();
165                let row_results: Vec<StorageResult<StructValue>> = self
166                    .snapshot
167                    .nearest(
168                        vector,
169                        VectorNearestOptions {
170                            top_n: self.reader.top_n,
171                            measure: self.reader.measure,
172                            hnsw_ef_search: self.reader.hnsw_ef_search,
173                        },
174                        move |_vec, distance, value| {
175                            let mut values = Vec::with_capacity(deserializer.data_types().len());
176                            deserializer.deserialize_to(value, &mut values)?;
177                            let mut info = Vec::with_capacity(struct_len);
178                            for idx in &*info_output_indices {
179                                info.push(values[*idx].clone());
180                            }
181                            if include_distance {
182                                let distance = if sqrt_distance {
183                                    distance.sqrt()
184                                } else {
185                                    distance
186                                };
187                                info.push(Some(ScalarImpl::Float64(distance.into())));
188                            }
189                            Ok(StructValue::new(info))
190                        },
191                    )
192                    .await?;
193                let mut struct_array_builder = StructArrayBuilder::with_type(
194                    row_results.len(),
195                    DataType::Struct(self.reader.vector_info_struct_type.clone()),
196                );
197                for row in row_results {
198                    let row = row?;
199                    struct_array_builder.append_owned(Some(row));
200                }
201                let struct_array = struct_array_builder.finish();
202
203                let value = ListValue::new(ArrayImpl::Struct(struct_array));
204                vector_info_columns_builder.append_owned(Some(value));
205            } else {
206                vector_info_columns_builder.append_null();
207            }
208        }
209        columns.push(ArrayImpl::List(vector_info_columns_builder.finish()).into());
210
211        Ok(DataChunk::new(columns, vis))
212    }
213}