risingwave_storage/table/batch_table/
vector_index_reader.rs1use std::sync::Arc;
16
17use itertools::Itertools;
18use risingwave_common::array::{
19 Array, ArrayBuilder, ArrayImpl, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder,
20 StructValue,
21};
22use risingwave_common::catalog::TableId;
23use risingwave_common::row::RowDeserializer;
24use risingwave_common::types::{DataType, ScalarImpl, StructType};
25use risingwave_common::util::value_encoding::BasicDeserializer;
26use risingwave_common::vector::distance::DistanceMeasurement;
27use risingwave_hummock_sdk::HummockReadEpoch;
28use risingwave_pb::common::PbDistanceType;
29use risingwave_pb::plan_common::PbVectorIndexReaderDesc;
30
31use crate::StateStore;
32use crate::error::StorageResult;
33use crate::store::{NewReadSnapshotOptions, StateStoreReadVector, VectorNearestOptions};
34
35pub struct VectorIndexReader<S> {
36 vector_info_struct_type: StructType,
37 state_store: S,
38 table_id: TableId,
39
40 info_output_indices: Arc<Vec<usize>>,
41 include_distance: bool,
42
43 top_n: usize,
44 measure: DistanceMeasurement,
45 sqrt_distance: bool,
46 deserializer: Arc<BasicDeserializer>,
47 hnsw_ef_search: usize,
48}
49
50impl<S: StateStore> VectorIndexReader<S> {
51 pub fn new(reader_desc: &PbVectorIndexReaderDesc, state_store: S) -> Self {
52 let deserializer = Arc::new(RowDeserializer::new(
53 reader_desc
54 .info_column_desc
55 .iter()
56 .map(|col| DataType::from(col.column_type.clone().unwrap()))
57 .collect_vec(),
58 ));
59
60 let vector_info_struct_type = StructType::new(
61 reader_desc
62 .info_output_indices
63 .iter()
64 .map(|idx| {
65 let idx = *idx as usize;
66 (
67 reader_desc.info_column_desc[idx].name.clone(),
68 DataType::from(
69 reader_desc.info_column_desc[idx]
70 .column_type
71 .clone()
72 .unwrap(),
73 ),
74 )
75 })
76 .chain(
77 reader_desc
78 .include_distance
79 .then(|| [("__distance".to_owned(), DataType::Float64)].into_iter())
80 .into_iter()
81 .flatten(),
82 ),
83 );
84
85 let measure = PbDistanceType::try_from(reader_desc.distance_type)
86 .unwrap()
87 .into();
88
89 let sqrt_distance = match measure {
90 DistanceMeasurement::L2Sqr => true,
91 DistanceMeasurement::L1
92 | DistanceMeasurement::Cosine
93 | DistanceMeasurement::InnerProduct => false,
94 };
95
96 Self {
97 vector_info_struct_type,
98 state_store,
99 table_id: reader_desc.table_id,
100
101 info_output_indices: reader_desc
102 .info_output_indices
103 .iter()
104 .map(|idx| *idx as _)
105 .collect_vec()
106 .into(),
107 include_distance: reader_desc.include_distance,
108 top_n: reader_desc.top_n as usize,
109 measure,
110 sqrt_distance,
111 deserializer,
112 hnsw_ef_search: reader_desc.hnsw_ef_search as usize,
113 }
114 }
115
116 pub fn info_struct_type(&self) -> &StructType {
117 &self.vector_info_struct_type
118 }
119
120 pub async fn new_snapshot(
121 &self,
122 epoch: HummockReadEpoch,
123 ) -> StorageResult<VectorIndexSnapshot<'_, S>> {
124 Ok(VectorIndexSnapshot {
125 reader: self,
126 snapshot: self
127 .state_store
128 .new_read_snapshot(
129 epoch,
130 NewReadSnapshotOptions {
131 table_id: self.table_id,
132 },
133 )
134 .await?,
135 })
136 }
137}
138
139pub struct VectorIndexSnapshot<'a, S: StateStore> {
140 reader: &'a VectorIndexReader<S>,
141 snapshot: S::ReadSnapshot,
142}
143
144impl<S: StateStore> VectorIndexSnapshot<'_, S> {
145 pub async fn query_expand_chunk(
146 &self,
147 chunk: DataChunk,
148 vector_column_idx: usize,
149 ) -> StorageResult<DataChunk> {
150 let sqrt_distance = self.reader.sqrt_distance;
151 let struct_len = self.reader.vector_info_struct_type.len();
152 let include_distance = self.reader.include_distance;
153
154 let mut vector_info_columns_builder = ListArrayBuilder::with_type(
155 chunk.cardinality(),
156 DataType::list(DataType::Struct(self.reader.info_struct_type().clone())),
157 );
158 let (mut columns, vis) = chunk.into_parts();
159 let vector_column = columns[vector_column_idx].as_vector();
160 for (idx, vis) in vis.iter().enumerate() {
161 if vis && let Some(vector) = vector_column.value_at(idx) {
162 let deserializer = self.reader.deserializer.clone();
163 let info_output_indices = self.reader.info_output_indices.clone();
164 let row_results: Vec<StorageResult<StructValue>> = self
165 .snapshot
166 .nearest(
167 vector,
168 VectorNearestOptions {
169 top_n: self.reader.top_n,
170 measure: self.reader.measure,
171 hnsw_ef_search: self.reader.hnsw_ef_search,
172 },
173 move |_vec, distance, value| {
174 let mut values = Vec::with_capacity(deserializer.data_types().len());
175 deserializer.deserialize_to(value, &mut values)?;
176 let mut info = Vec::with_capacity(struct_len);
177 for idx in &*info_output_indices {
178 info.push(values[*idx].clone());
179 }
180 if include_distance {
181 let distance = if sqrt_distance {
182 distance.sqrt()
183 } else {
184 distance
185 };
186 info.push(Some(ScalarImpl::Float64(distance.into())));
187 }
188 Ok(StructValue::new(info))
189 },
190 )
191 .await?;
192 let mut struct_array_builder = StructArrayBuilder::with_type(
193 row_results.len(),
194 DataType::Struct(self.reader.vector_info_struct_type.clone()),
195 );
196 for row in row_results {
197 let row = row?;
198 struct_array_builder.append_owned(Some(row));
199 }
200 let struct_array = struct_array_builder.finish();
201
202 let value = ListValue::new(ArrayImpl::Struct(struct_array));
203 vector_info_columns_builder.append_owned(Some(value));
204 } else {
205 vector_info_columns_builder.append_null();
206 }
207 }
208 columns.push(ArrayImpl::List(vector_info_columns_builder.finish()).into());
209
210 Ok(DataChunk::new(columns, vis))
211 }
212}