risingwave_batch_executors/executor/
vector_index_nearest.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use futures::pin_mut;
18use futures::prelude::stream::StreamExt;
19use futures_async_stream::try_stream;
20use futures_util::TryStreamExt;
21use itertools::Itertools;
22use risingwave_common::array::{
23    Array, ArrayBuilder, ArrayImpl, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder,
24    StructValue,
25};
26use risingwave_common::catalog::{Field, Schema, TableId};
27use risingwave_common::row::RowDeserializer;
28use risingwave_common::types::{DataType, ScalarImpl, ScalarRef, StructType};
29use risingwave_common::util::value_encoding::BasicDeserializer;
30use risingwave_common::vector::distance::DistanceMeasurement;
31use risingwave_pb::batch_plan::plan_node::NodeBody;
32use risingwave_pb::common::{BatchQueryEpoch, PbDistanceType};
33use risingwave_storage::store::{
34    NewReadSnapshotOptions, StateStoreReadVector, VectorNearestOptions,
35};
36use risingwave_storage::{StateStore, dispatch_state_store};
37
38use super::{BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder};
39use crate::error::{BatchError, Result};
40
/// Batch executor that, for every input row, looks up the nearest entries of a
/// vector index and appends them to the row as a `vector_info` list column.
///
/// Generic over the state store `S` from which the read snapshot is opened.
pub struct VectorIndexNearestExecutor<S: StateStore> {
    /// Identity string reported through `Executor::identity`; copied from the plan node.
    identity: String,
    /// Output schema: the input executor's schema followed by the `vector_info` list column.
    schema: Schema,
    /// Struct type of one list element: the info columns followed by a
    /// `__distance` field of type `Float64`.
    vector_info_struct_type: StructType,

    /// Upstream executor whose chunks are extended with the lookup results.
    input: BoxedExecutor,

    /// State store used to open the read snapshot that serves the lookups.
    state_store: S,
    /// Table id passed to `NewReadSnapshotOptions` when opening the snapshot.
    table_id: TableId,
    /// Query epoch at which the read snapshot is opened.
    epoch: BatchQueryEpoch,
    /// Index of the input column that is read as the query vector.
    vector_column_idx: usize,
    /// Forwarded as `top_n` in `VectorNearestOptions`.
    top_n: usize,
    /// Distance measurement forwarded to the lookup; for `L2Sqr` the reported
    /// distance is square-rooted before being emitted.
    measure: DistanceMeasurement,
    /// Deserializes the value handed to the lookup callback into the info column values.
    deserializer: BasicDeserializer,

    /// Forwarded as `hnsw_ef_search` in `VectorNearestOptions`.
    // NOTE(review): presumably the HNSW `ef` search parameter — confirm against the storage layer.
    hnsw_ef_search: usize,
}
58
/// Stateless builder that turns a `NodeBody::VectorIndexNearest` plan node into
/// a boxed `VectorIndexNearestExecutor`.
pub struct VectorIndexNearestExecutorBuilder {}
60
61impl BoxedExecutorBuilder for VectorIndexNearestExecutorBuilder {
62    async fn new_boxed_executor(
63        source: &ExecutorBuilder<'_>,
64        inputs: Vec<BoxedExecutor>,
65    ) -> Result<BoxedExecutor> {
66        ensure!(
67            inputs.len() == 1,
68            "VectorIndexNearest should have an input executor!"
69        );
70        let [input]: [_; 1] = inputs.try_into().unwrap();
71        let vector_index_nearest_node = try_match_expand!(
72            source.plan_node().get_node_body().unwrap(),
73            NodeBody::VectorIndexNearest
74        )?;
75
76        let deserializer = RowDeserializer::new(
77            vector_index_nearest_node
78                .info_column_desc
79                .iter()
80                .map(|col| DataType::from(col.column_type.clone().unwrap()))
81                .collect_vec(),
82        );
83
84        let vector_info_struct_type = StructType::new(
85            vector_index_nearest_node
86                .info_column_desc
87                .iter()
88                .map(|col| {
89                    (
90                        col.name.clone(),
91                        DataType::from(col.column_type.clone().unwrap()),
92                    )
93                })
94                .chain([("__distance".to_owned(), DataType::Float64)]),
95        );
96
97        let mut schema = input.schema().clone();
98        schema.fields.push(Field::new(
99            "vector_info",
100            DataType::List(DataType::Struct(vector_info_struct_type.clone()).into()),
101        ));
102
103        let epoch = source.epoch();
104        dispatch_state_store!(source.context().state_store(), state_store, {
105            Ok(Box::new(VectorIndexNearestExecutor {
106                identity: source.plan_node().get_identity().clone(),
107                schema,
108                vector_info_struct_type,
109                input,
110                state_store,
111                table_id: vector_index_nearest_node.table_id.into(),
112                epoch,
113                vector_column_idx: vector_index_nearest_node.vector_column_idx as usize,
114                top_n: vector_index_nearest_node.top_n as usize,
115                measure: PbDistanceType::try_from(vector_index_nearest_node.distance_type)
116                    .unwrap()
117                    .into(),
118                deserializer,
119                hnsw_ef_search: vector_index_nearest_node.hnsw_ef_search as usize,
120            }))
121        })
122    }
123}
124impl<S: StateStore> Executor for VectorIndexNearestExecutor<S> {
125    fn schema(&self) -> &Schema {
126        &self.schema
127    }
128
129    fn identity(&self) -> &str {
130        &self.identity
131    }
132
133    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
134        self.do_execute().boxed()
135    }
136}
137
138impl<S: StateStore> VectorIndexNearestExecutor<S> {
139    #[try_stream(ok = DataChunk, error = BatchError)]
140    async fn do_execute(self: Box<Self>) {
141        let Self {
142            state_store,
143            table_id,
144            epoch,
145            vector_info_struct_type,
146            input,
147            vector_column_idx,
148            top_n,
149            measure,
150            deserializer,
151            hnsw_ef_search,
152            ..
153        } = *self;
154
155        let input = input.execute();
156        pin_mut!(input);
157
158        let read_snapshot: S::ReadSnapshot = state_store
159            .new_read_snapshot(epoch.into(), NewReadSnapshotOptions { table_id })
160            .await?;
161
162        let deserializer = Arc::new(deserializer);
163        let sqrt_distance = match &self.measure {
164            DistanceMeasurement::L2Sqr => true,
165            DistanceMeasurement::L1
166            | DistanceMeasurement::Cosine
167            | DistanceMeasurement::InnerProduct => false,
168        };
169
170        while let Some(chunk) = input.try_next().await? {
171            let mut vector_info_columns_builder = ListArrayBuilder::with_type(
172                chunk.cardinality(),
173                DataType::List(DataType::Struct(vector_info_struct_type.clone()).into()),
174            );
175            let (mut columns, vis) = chunk.into_parts();
176            let vector_column = columns[vector_column_idx].as_vector();
177            for (idx, vis) in vis.iter().enumerate() {
178                if vis && let Some(vector) = vector_column.value_at(idx) {
179                    let deserializer = deserializer.clone();
180                    let row_results: Vec<Result<StructValue>> = read_snapshot
181                        .nearest(
182                            vector.to_owned_scalar(),
183                            VectorNearestOptions {
184                                top_n,
185                                measure,
186                                hnsw_ef_search,
187                            },
188                            move |_vec, distance, value| {
189                                let mut values =
190                                    Vec::with_capacity(deserializer.data_types().len() + 1);
191                                deserializer.deserialize_to(value, &mut values)?;
192                                let distance = if sqrt_distance {
193                                    distance.sqrt()
194                                } else {
195                                    distance
196                                };
197                                values.push(Some(ScalarImpl::Float64(distance.into())));
198                                Ok(StructValue::new(values))
199                            },
200                        )
201                        .await?;
202                    let mut struct_array_builder = StructArrayBuilder::with_type(
203                        row_results.len(),
204                        DataType::Struct(vector_info_struct_type.clone()),
205                    );
206                    for row in row_results {
207                        let row = row?;
208                        struct_array_builder.append_owned(Some(row));
209                    }
210                    let struct_array = struct_array_builder.finish();
211                    vector_info_columns_builder
212                        .append_owned(Some(ListValue::new(ArrayImpl::Struct(struct_array))));
213                } else {
214                    vector_info_columns_builder.append_null();
215                }
216            }
217            columns.push(ArrayImpl::List(vector_info_columns_builder.finish()).into());
218
219            yield DataChunk::new(columns, vis);
220        }
221    }
222}