risingwave_batch_executors/executor/
vector_index_nearest.rs1use std::sync::Arc;
16
17use futures::pin_mut;
18use futures::prelude::stream::StreamExt;
19use futures_async_stream::try_stream;
20use futures_util::TryStreamExt;
21use itertools::Itertools;
22use risingwave_common::array::{
23 Array, ArrayBuilder, ArrayImpl, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder,
24 StructValue,
25};
26use risingwave_common::catalog::{Field, Schema, TableId};
27use risingwave_common::row::RowDeserializer;
28use risingwave_common::types::{DataType, ScalarImpl, ScalarRef, StructType};
29use risingwave_common::util::value_encoding::BasicDeserializer;
30use risingwave_common::vector::distance::DistanceMeasurement;
31use risingwave_pb::batch_plan::plan_node::NodeBody;
32use risingwave_pb::common::{BatchQueryEpoch, PbDistanceType};
33use risingwave_storage::store::{
34 NewReadSnapshotOptions, StateStoreReadVector, VectorNearestOptions,
35};
36use risingwave_storage::{StateStore, dispatch_state_store};
37
38use super::{BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder};
39use crate::error::{BatchError, Result};
40
41pub struct VectorIndexNearestExecutor<S: StateStore> {
42 identity: String,
43 schema: Schema,
44 vector_info_struct_type: StructType,
45
46 input: BoxedExecutor,
47
48 state_store: S,
49 table_id: TableId,
50 epoch: BatchQueryEpoch,
51 vector_column_idx: usize,
52 top_n: usize,
53 measure: DistanceMeasurement,
54 deserializer: BasicDeserializer,
55
56 hnsw_ef_search: usize,
57}
58
/// Builds a [`VectorIndexNearestExecutor`] from a `VectorIndexNearest` plan node.
pub struct VectorIndexNearestExecutorBuilder {}
60
61impl BoxedExecutorBuilder for VectorIndexNearestExecutorBuilder {
62 async fn new_boxed_executor(
63 source: &ExecutorBuilder<'_>,
64 inputs: Vec<BoxedExecutor>,
65 ) -> Result<BoxedExecutor> {
66 ensure!(
67 inputs.len() == 1,
68 "VectorIndexNearest should have an input executor!"
69 );
70 let [input]: [_; 1] = inputs.try_into().unwrap();
71 let vector_index_nearest_node = try_match_expand!(
72 source.plan_node().get_node_body().unwrap(),
73 NodeBody::VectorIndexNearest
74 )?;
75
76 let deserializer = RowDeserializer::new(
77 vector_index_nearest_node
78 .info_column_desc
79 .iter()
80 .map(|col| DataType::from(col.column_type.clone().unwrap()))
81 .collect_vec(),
82 );
83
84 let vector_info_struct_type = StructType::new(
85 vector_index_nearest_node
86 .info_column_desc
87 .iter()
88 .map(|col| {
89 (
90 col.name.clone(),
91 DataType::from(col.column_type.clone().unwrap()),
92 )
93 })
94 .chain([("__distance".to_owned(), DataType::Float64)]),
95 );
96
97 let mut schema = input.schema().clone();
98 schema.fields.push(Field::new(
99 "vector_info",
100 DataType::List(DataType::Struct(vector_info_struct_type.clone()).into()),
101 ));
102
103 let epoch = source.epoch();
104 dispatch_state_store!(source.context().state_store(), state_store, {
105 Ok(Box::new(VectorIndexNearestExecutor {
106 identity: source.plan_node().get_identity().clone(),
107 schema,
108 vector_info_struct_type,
109 input,
110 state_store,
111 table_id: vector_index_nearest_node.table_id.into(),
112 epoch,
113 vector_column_idx: vector_index_nearest_node.vector_column_idx as usize,
114 top_n: vector_index_nearest_node.top_n as usize,
115 measure: PbDistanceType::try_from(vector_index_nearest_node.distance_type)
116 .unwrap()
117 .into(),
118 deserializer,
119 hnsw_ef_search: vector_index_nearest_node.hnsw_ef_search as usize,
120 }))
121 })
122 }
123}
124impl<S: StateStore> Executor for VectorIndexNearestExecutor<S> {
125 fn schema(&self) -> &Schema {
126 &self.schema
127 }
128
129 fn identity(&self) -> &str {
130 &self.identity
131 }
132
133 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
134 self.do_execute().boxed()
135 }
136}
137
138impl<S: StateStore> VectorIndexNearestExecutor<S> {
139 #[try_stream(ok = DataChunk, error = BatchError)]
140 async fn do_execute(self: Box<Self>) {
141 let Self {
142 state_store,
143 table_id,
144 epoch,
145 vector_info_struct_type,
146 input,
147 vector_column_idx,
148 top_n,
149 measure,
150 deserializer,
151 hnsw_ef_search,
152 ..
153 } = *self;
154
155 let input = input.execute();
156 pin_mut!(input);
157
158 let read_snapshot: S::ReadSnapshot = state_store
159 .new_read_snapshot(epoch.into(), NewReadSnapshotOptions { table_id })
160 .await?;
161
162 let deserializer = Arc::new(deserializer);
163 let sqrt_distance = match &self.measure {
164 DistanceMeasurement::L2Sqr => true,
165 DistanceMeasurement::L1
166 | DistanceMeasurement::Cosine
167 | DistanceMeasurement::InnerProduct => false,
168 };
169
170 while let Some(chunk) = input.try_next().await? {
171 let mut vector_info_columns_builder = ListArrayBuilder::with_type(
172 chunk.cardinality(),
173 DataType::List(DataType::Struct(vector_info_struct_type.clone()).into()),
174 );
175 let (mut columns, vis) = chunk.into_parts();
176 let vector_column = columns[vector_column_idx].as_vector();
177 for (idx, vis) in vis.iter().enumerate() {
178 if vis && let Some(vector) = vector_column.value_at(idx) {
179 let deserializer = deserializer.clone();
180 let row_results: Vec<Result<StructValue>> = read_snapshot
181 .nearest(
182 vector.to_owned_scalar(),
183 VectorNearestOptions {
184 top_n,
185 measure,
186 hnsw_ef_search,
187 },
188 move |_vec, distance, value| {
189 let mut values =
190 Vec::with_capacity(deserializer.data_types().len() + 1);
191 deserializer.deserialize_to(value, &mut values)?;
192 let distance = if sqrt_distance {
193 distance.sqrt()
194 } else {
195 distance
196 };
197 values.push(Some(ScalarImpl::Float64(distance.into())));
198 Ok(StructValue::new(values))
199 },
200 )
201 .await?;
202 let mut struct_array_builder = StructArrayBuilder::with_type(
203 row_results.len(),
204 DataType::Struct(vector_info_struct_type.clone()),
205 );
206 for row in row_results {
207 let row = row?;
208 struct_array_builder.append_owned(Some(row));
209 }
210 let struct_array = struct_array_builder.finish();
211 vector_info_columns_builder
212 .append_owned(Some(ListValue::new(ArrayImpl::Struct(struct_array))));
213 } else {
214 vector_info_columns_builder.append_null();
215 }
216 }
217 columns.push(ArrayImpl::List(vector_info_columns_builder.finish()).into());
218
219 yield DataChunk::new(columns, vis);
220 }
221 }
222}