risingwave_storage/table/batch_table/
vector_index_reader.rs1use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::array::{
19    Array, ArrayBuilder, ArrayImpl, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder,
20    StructValue,
21};
22use risingwave_common::catalog::TableId;
23use risingwave_common::row::RowDeserializer;
24use risingwave_common::types::{DataType, ScalarImpl, StructType};
25use risingwave_common::util::value_encoding::BasicDeserializer;
26use risingwave_common::vector::distance::DistanceMeasurement;
27use risingwave_hummock_sdk::HummockReadEpoch;
28use risingwave_pb::common::PbDistanceType;
29use risingwave_pb::plan_common::PbVectorIndexReaderDesc;
30
31use crate::StateStore;
32use crate::error::StorageResult;
33use crate::store::{NewReadSnapshotOptions, StateStoreReadVector, VectorNearestOptions};
34
35pub struct VectorIndexReader<S> {
36    vector_info_struct_type: StructType,
37    state_store: S,
38    table_id: TableId,
39
40    info_output_indices: Arc<Vec<usize>>,
41    include_distance: bool,
42
43    top_n: usize,
44    measure: DistanceMeasurement,
45    sqrt_distance: bool,
46    deserializer: Arc<BasicDeserializer>,
47    hnsw_ef_search: usize,
48}
49
50impl<S: StateStore> VectorIndexReader<S> {
51    pub fn new(reader_desc: &PbVectorIndexReaderDesc, state_store: S) -> Self {
52        let deserializer = Arc::new(RowDeserializer::new(
53            reader_desc
54                .info_column_desc
55                .iter()
56                .map(|col| DataType::from(col.column_type.clone().unwrap()))
57                .collect_vec(),
58        ));
59
60        let vector_info_struct_type = StructType::new(
61            reader_desc
62                .info_output_indices
63                .iter()
64                .map(|idx| {
65                    let idx = *idx as usize;
66                    (
67                        reader_desc.info_column_desc[idx].name.clone(),
68                        DataType::from(
69                            reader_desc.info_column_desc[idx]
70                                .column_type
71                                .clone()
72                                .unwrap(),
73                        ),
74                    )
75                })
76                .chain(
77                    reader_desc
78                        .include_distance
79                        .then(|| [("__distance".to_owned(), DataType::Float64)].into_iter())
80                        .into_iter()
81                        .flatten(),
82                ),
83        );
84
85        let measure = PbDistanceType::try_from(reader_desc.distance_type)
86            .unwrap()
87            .into();
88
89        let sqrt_distance = match measure {
90            DistanceMeasurement::L2Sqr => true,
91            DistanceMeasurement::L1
92            | DistanceMeasurement::Cosine
93            | DistanceMeasurement::InnerProduct => false,
94        };
95
96        Self {
97            vector_info_struct_type,
98            state_store,
99            table_id: reader_desc.table_id.into(),
100
101            info_output_indices: reader_desc
102                .info_output_indices
103                .iter()
104                .map(|idx| *idx as _)
105                .collect_vec()
106                .into(),
107            include_distance: reader_desc.include_distance,
108            top_n: reader_desc.top_n as usize,
109            measure,
110            sqrt_distance,
111            deserializer,
112            hnsw_ef_search: reader_desc.hnsw_ef_search as usize,
113        }
114    }
115
116    pub fn info_struct_type(&self) -> &StructType {
117        &self.vector_info_struct_type
118    }
119
120    pub async fn new_snapshot(
121        &self,
122        epoch: HummockReadEpoch,
123    ) -> StorageResult<VectorIndexSnapshot<'_, S>> {
124        Ok(VectorIndexSnapshot {
125            reader: self,
126            snapshot: self
127                .state_store
128                .new_read_snapshot(
129                    epoch,
130                    NewReadSnapshotOptions {
131                        table_id: self.table_id,
132                    },
133                )
134                .await?,
135        })
136    }
137}
138
139pub struct VectorIndexSnapshot<'a, S: StateStore> {
140    reader: &'a VectorIndexReader<S>,
141    snapshot: S::ReadSnapshot,
142}
143
144impl<S: StateStore> VectorIndexSnapshot<'_, S> {
145    pub async fn query_expand_chunk(
146        &self,
147        chunk: DataChunk,
148        vector_column_idx: usize,
149    ) -> StorageResult<DataChunk> {
150        let sqrt_distance = self.reader.sqrt_distance;
151        let struct_len = self.reader.vector_info_struct_type.len();
152        let include_distance = self.reader.include_distance;
153
154        let mut vector_info_columns_builder = ListArrayBuilder::with_type(
155            chunk.cardinality(),
156            DataType::list(DataType::Struct(self.reader.info_struct_type().clone())),
157        );
158        let (mut columns, vis) = chunk.into_parts();
159        let vector_column = columns[vector_column_idx].as_vector();
160        for (idx, vis) in vis.iter().enumerate() {
161            if vis && let Some(vector) = vector_column.value_at(idx) {
162                let deserializer = self.reader.deserializer.clone();
163                let info_output_indices = self.reader.info_output_indices.clone();
164                let row_results: Vec<StorageResult<StructValue>> = self
165                    .snapshot
166                    .nearest(
167                        vector,
168                        VectorNearestOptions {
169                            top_n: self.reader.top_n,
170                            measure: self.reader.measure,
171                            hnsw_ef_search: self.reader.hnsw_ef_search,
172                        },
173                        move |_vec, distance, value| {
174                            let mut values = Vec::with_capacity(deserializer.data_types().len());
175                            deserializer.deserialize_to(value, &mut values)?;
176                            let mut info = Vec::with_capacity(struct_len);
177                            for idx in &*info_output_indices {
178                                info.push(values[*idx].clone());
179                            }
180                            if include_distance {
181                                let distance = if sqrt_distance {
182                                    distance.sqrt()
183                                } else {
184                                    distance
185                                };
186                                info.push(Some(ScalarImpl::Float64(distance.into())));
187                            }
188                            Ok(StructValue::new(info))
189                        },
190                    )
191                    .await?;
192                let mut struct_array_builder = StructArrayBuilder::with_type(
193                    row_results.len(),
194                    DataType::Struct(self.reader.vector_info_struct_type.clone()),
195                );
196                for row in row_results {
197                    let row = row?;
198                    struct_array_builder.append_owned(Some(row));
199                }
200                let struct_array = struct_array_builder.finish();
201
202                let value = ListValue::new(ArrayImpl::Struct(struct_array));
203                vector_info_columns_builder.append_owned(Some(value));
204            } else {
205                vector_info_columns_builder.append_null();
206            }
207        }
208        columns.push(ArrayImpl::List(vector_info_columns_builder.finish()).into());
209
210        Ok(DataChunk::new(columns, vis))
211    }
212}