risingwave_storage/table/batch_table/
vector_index_reader.rs1use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::array::{
19 Array, ArrayBuilder, ArrayImpl, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder,
20 StructValue,
21};
22use risingwave_common::catalog::{TableId, TableOption};
23use risingwave_common::row::RowDeserializer;
24use risingwave_common::types::{DataType, ScalarImpl, StructType};
25use risingwave_common::util::value_encoding::BasicDeserializer;
26use risingwave_common::vector::distance::DistanceMeasurement;
27use risingwave_hummock_sdk::HummockReadEpoch;
28use risingwave_pb::common::PbDistanceType;
29use risingwave_pb::plan_common::PbVectorIndexReaderDesc;
30
31use crate::StateStore;
32use crate::error::StorageResult;
33use crate::store::{NewReadSnapshotOptions, StateStoreReadVector, VectorNearestOptions};
34
35pub struct VectorIndexReader<S> {
36 vector_info_struct_type: StructType,
37 state_store: S,
38 table_id: TableId,
39
40 info_output_indices: Arc<Vec<usize>>,
41 include_distance: bool,
42
43 top_n: usize,
44 measure: DistanceMeasurement,
45 sqrt_distance: bool,
46 deserializer: Arc<BasicDeserializer>,
47 hnsw_ef_search: usize,
48}
49
50impl<S: StateStore> VectorIndexReader<S> {
51 pub fn new(reader_desc: &PbVectorIndexReaderDesc, state_store: S) -> Self {
52 let deserializer = Arc::new(RowDeserializer::new(
53 reader_desc
54 .info_column_desc
55 .iter()
56 .map(|col| DataType::from(col.column_type.clone().unwrap()))
57 .collect_vec(),
58 ));
59
60 let vector_info_struct_type = StructType::new(
61 reader_desc
62 .info_output_indices
63 .iter()
64 .map(|idx| {
65 let idx = *idx as usize;
66 (
67 reader_desc.info_column_desc[idx].name.clone(),
68 DataType::from(
69 reader_desc.info_column_desc[idx]
70 .column_type
71 .clone()
72 .unwrap(),
73 ),
74 )
75 })
76 .chain(
77 reader_desc
78 .include_distance
79 .then(|| [("__distance".to_owned(), DataType::Float64)].into_iter())
80 .into_iter()
81 .flatten(),
82 ),
83 );
84
85 let measure = PbDistanceType::try_from(reader_desc.distance_type)
86 .unwrap()
87 .into();
88
89 let sqrt_distance = match measure {
90 DistanceMeasurement::L2Sqr => true,
91 DistanceMeasurement::L1
92 | DistanceMeasurement::Cosine
93 | DistanceMeasurement::InnerProduct => false,
94 };
95
96 Self {
97 vector_info_struct_type,
98 state_store,
99 table_id: reader_desc.table_id,
100
101 info_output_indices: reader_desc
102 .info_output_indices
103 .iter()
104 .map(|idx| *idx as _)
105 .collect_vec()
106 .into(),
107 include_distance: reader_desc.include_distance,
108 top_n: reader_desc.top_n as usize,
109 measure,
110 sqrt_distance,
111 deserializer,
112 hnsw_ef_search: reader_desc.hnsw_ef_search as usize,
113 }
114 }
115
116 pub fn info_struct_type(&self) -> &StructType {
117 &self.vector_info_struct_type
118 }
119
120 pub async fn new_snapshot(
121 &self,
122 epoch: HummockReadEpoch,
123 ) -> StorageResult<VectorIndexSnapshot<'_, S>> {
124 Ok(VectorIndexSnapshot {
125 reader: self,
126 snapshot: self
127 .state_store
128 .new_read_snapshot(
129 epoch,
130 NewReadSnapshotOptions {
131 table_id: self.table_id,
132 table_option: TableOption::default(),
133 },
134 )
135 .await?,
136 })
137 }
138}
139
140pub struct VectorIndexSnapshot<'a, S: StateStore> {
141 reader: &'a VectorIndexReader<S>,
142 snapshot: S::ReadSnapshot,
143}
144
145impl<S: StateStore> VectorIndexSnapshot<'_, S> {
146 pub async fn query_expand_chunk(
147 &self,
148 chunk: DataChunk,
149 vector_column_idx: usize,
150 ) -> StorageResult<DataChunk> {
151 let sqrt_distance = self.reader.sqrt_distance;
152 let struct_len = self.reader.vector_info_struct_type.len();
153 let include_distance = self.reader.include_distance;
154
155 let mut vector_info_columns_builder = ListArrayBuilder::with_type(
156 chunk.cardinality(),
157 DataType::list(DataType::Struct(self.reader.info_struct_type().clone())),
158 );
159 let (mut columns, vis) = chunk.into_parts();
160 let vector_column = columns[vector_column_idx].as_vector();
161 for (idx, vis) in vis.iter().enumerate() {
162 if vis && let Some(vector) = vector_column.value_at(idx) {
163 let deserializer = self.reader.deserializer.clone();
164 let info_output_indices = self.reader.info_output_indices.clone();
165 let row_results: Vec<StorageResult<StructValue>> = self
166 .snapshot
167 .nearest(
168 vector,
169 VectorNearestOptions {
170 top_n: self.reader.top_n,
171 measure: self.reader.measure,
172 hnsw_ef_search: self.reader.hnsw_ef_search,
173 },
174 move |_vec, distance, value| {
175 let mut values = Vec::with_capacity(deserializer.data_types().len());
176 deserializer.deserialize_to(value, &mut values)?;
177 let mut info = Vec::with_capacity(struct_len);
178 for idx in &*info_output_indices {
179 info.push(values[*idx].clone());
180 }
181 if include_distance {
182 let distance = if sqrt_distance {
183 distance.sqrt()
184 } else {
185 distance
186 };
187 info.push(Some(ScalarImpl::Float64(distance.into())));
188 }
189 Ok(StructValue::new(info))
190 },
191 )
192 .await?;
193 let mut struct_array_builder = StructArrayBuilder::with_type(
194 row_results.len(),
195 DataType::Struct(self.reader.vector_info_struct_type.clone()),
196 );
197 for row in row_results {
198 let row = row?;
199 struct_array_builder.append_owned(Some(row));
200 }
201 let struct_array = struct_array_builder.finish();
202
203 let value = ListValue::new(ArrayImpl::Struct(struct_array));
204 vector_info_columns_builder.append_owned(Some(value));
205 } else {
206 vector_info_columns_builder.append_null();
207 }
208 }
209 columns.push(ArrayImpl::List(vector_info_columns_builder.finish()).into());
210
211 Ok(DataChunk::new(columns, vis))
212 }
213}