risingwave_frontend/optimizer/plan_node/generic/
vector_index_lookup_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use pretty_xmlish::Pretty;
16use risingwave_common::catalog::{ColumnDesc, Field, Schema};
17use risingwave_common::types::{DataType, StructType};
18use risingwave_pb::common::PbDistanceType;
19use risingwave_pb::plan_common::PbVectorIndexReaderDesc;
20use risingwave_sqlparser::ast::AsOf;
21
22use crate::OptimizerContextRef;
23use crate::catalog::TableId;
24use crate::expr::{ExprDisplay, ExprImpl, InputRef};
25use crate::optimizer::plan_node::generic::{GenericPlanNode, GenericPlanRef};
26use crate::optimizer::property::FunctionalDependencySet;
27
28#[derive(Debug, Clone, educe::Educe)]
29#[educe(Hash, PartialEq, Eq)]
30pub struct VectorIndexLookupJoin<PlanRef> {
31    pub input: PlanRef,
32    pub top_n: u64,
33    pub distance_type: PbDistanceType,
34    pub index_name: String,
35    pub index_table_id: TableId,
36    pub info_column_desc: Vec<ColumnDesc>,
37    pub info_output_indices: Vec<usize>,
38    pub include_distance: bool,
39    pub as_of: Option<AsOf>,
40
41    pub vector_column_idx: usize,
42    pub hnsw_ef_search: Option<usize>,
43    #[educe(Hash(ignore), Eq(ignore))]
44    pub ctx: OptimizerContextRef,
45}
46
47impl<PlanRef: GenericPlanRef> GenericPlanNode for VectorIndexLookupJoin<PlanRef> {
48    fn functional_dependency(&self) -> FunctionalDependencySet {
49        // TODO: copy the one of input, and extend with an extra column with no dependency
50        FunctionalDependencySet::new(
51            self.info_output_indices.len() + if self.include_distance { 1 } else { 0 },
52        )
53    }
54
55    fn schema(&self) -> Schema {
56        let mut schema = self.input.schema().clone();
57        schema.fields.push(Field::new(
58            "vector_info",
59            DataType::list(
60                StructType::new(
61                    self.info_output_indices
62                        .iter()
63                        .map(|idx| {
64                            (
65                                self.info_column_desc[*idx].name.clone(),
66                                self.info_column_desc[*idx].data_type.clone(),
67                            )
68                        })
69                        .chain(
70                            self.include_distance
71                                .then(|| [("__distance".to_owned(), DataType::Float64)].into_iter())
72                                .into_iter()
73                                .flatten(),
74                        ),
75                )
76                .into(),
77            ),
78        ));
79        schema
80    }
81
82    fn stream_key(&self) -> Option<Vec<usize>> {
83        self.input.stream_key().map(|key| key.to_vec())
84    }
85
86    fn ctx(&self) -> OptimizerContextRef {
87        self.ctx.clone()
88    }
89}
90
91impl<PlanRef: GenericPlanRef> VectorIndexLookupJoin<PlanRef> {
92    pub fn distill<'a>(&self) -> Vec<(&'static str, Pretty<'a>)> {
93        let mut fields = vec![
94            ("top_n", Pretty::debug(&self.top_n)),
95            ("distance_type", Pretty::debug(&self.distance_type)),
96            ("index_name", Pretty::debug(&self.index_name)),
97            (
98                "vector",
99                Pretty::debug(&ExprDisplay {
100                    expr: &ExprImpl::InputRef(
101                        InputRef::new(
102                            self.vector_column_idx,
103                            self.input.schema()[self.vector_column_idx].data_type(),
104                        )
105                        .into(),
106                    ),
107                    input_schema: self.input.schema(),
108                }),
109            ),
110            (
111                "lookup_output",
112                Pretty::Array(
113                    self.info_output_indices
114                        .iter()
115                        .map(|idx| {
116                            let col = &self.info_column_desc[*idx];
117                            Pretty::debug(&(&col.name, &col.data_type))
118                        })
119                        .collect(),
120                ),
121            ),
122            ("include_distance", Pretty::debug(&self.include_distance)),
123        ];
124        if let Some(hnsw_ef_search) = self.hnsw_ef_search {
125            fields.push(("hnsw_ef_search", Pretty::debug(&hnsw_ef_search)));
126        }
127        if let Some(as_of) = &self.as_of {
128            fields.push(("as_of", Pretty::debug(&as_of)));
129        }
130        fields
131    }
132
133    pub fn to_reader_desc(&self) -> PbVectorIndexReaderDesc {
134        PbVectorIndexReaderDesc {
135            table_id: self.index_table_id.table_id,
136            info_column_desc: self
137                .info_column_desc
138                .iter()
139                .map(|col| col.to_protobuf())
140                .collect(),
141            top_n: self.top_n as _,
142            distance_type: self.distance_type as _,
143            hnsw_ef_search: self.hnsw_ef_search.unwrap_or(0) as _,
144            info_output_indices: self
145                .info_output_indices
146                .iter()
147                .map(|&idx| idx as _)
148                .collect(),
149            include_distance: self.include_distance,
150        }
151    }
152}