risingwave_frontend/optimizer/plan_node/
logical_vector_search_lookup_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::array::VECTOR_DISTANCE_TYPE;
19use risingwave_common::bail;
20use risingwave_common::catalog::{Field, Schema};
21use risingwave_common::types::{DataType, StructType};
22use risingwave_common::util::column_index_mapping::ColIndexMapping;
23use risingwave_pb::common::PbDistanceType;
24use risingwave_sqlparser::ast::AsOf;
25
26use crate::OptimizerContextRef;
27use crate::catalog::index_catalog::VectorIndex;
28use crate::expr::{ExprDisplay, ExprImpl};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::{
31    GenericPlanNode, GenericPlanRef, VectorIndexLookupJoin, ensure_sorted_required_cols,
32};
33use crate::optimizer::plan_node::utils::{Distill, childless_record};
34use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
35use crate::optimizer::property::FunctionalDependencySet;
36use crate::utils::Condition;
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39struct VectorSearchLookupJoinCore {
40    top_n: u64,
41    distance_type: PbDistanceType,
42
43    input: PlanRef,
44    input_vector_col_idx: usize,
45    lookup: PlanRef,
46    lookup_vector: ExprImpl,
47
48    /// The indices of lookup that will be included in the output.
49    /// The index of distance column is `lookup_output_indices.len()`
50    lookup_output_indices: Vec<usize>,
51    include_distance: bool,
52}
53
54impl VectorSearchLookupJoinCore {
55    pub(crate) fn clone_with_input(&self, input: PlanRef, lookup: PlanRef) -> Self {
56        Self {
57            top_n: self.top_n,
58            distance_type: self.distance_type,
59            input,
60            input_vector_col_idx: self.input_vector_col_idx,
61            lookup,
62            lookup_vector: self.lookup_vector.clone(),
63            lookup_output_indices: self.lookup_output_indices.clone(),
64            include_distance: self.include_distance,
65        }
66    }
67
68    fn struct_type(&self) -> StructType {
69        StructType::new(
70            self.lookup_output_indices
71                .iter()
72                .map(|i| {
73                    let field = &self.lookup.schema().fields[*i];
74                    (field.name.clone(), field.data_type.clone())
75                })
76                .chain(
77                    self.include_distance
78                        .then(|| ("vector_distance".to_owned(), VECTOR_DISTANCE_TYPE)),
79                ),
80        )
81    }
82}
83
84impl GenericPlanNode for VectorSearchLookupJoinCore {
85    fn functional_dependency(&self) -> FunctionalDependencySet {
86        // TODO: include dependency of array_agg column
87        FunctionalDependencySet::new(self.input.schema().len() + 1)
88    }
89
90    fn schema(&self) -> Schema {
91        let fields = self
92            .input
93            .schema()
94            .fields
95            .iter()
96            .cloned()
97            .chain([Field::new(
98                "array",
99                DataType::Struct(self.struct_type()).list(),
100            )])
101            .collect();
102
103        Schema { fields }
104    }
105
106    fn stream_key(&self) -> Option<Vec<usize>> {
107        self.input.stream_key().map(|key| key.to_vec())
108    }
109
110    fn ctx(&self) -> OptimizerContextRef {
111        self.input.ctx()
112    }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct LogicalVectorSearchLookupJoin {
117    pub base: PlanBase<Logical>,
118    core: VectorSearchLookupJoinCore,
119}
120
121impl LogicalVectorSearchLookupJoin {
122    pub(crate) fn new(
123        top_n: u64,
124        distance_type: PbDistanceType,
125        input: PlanRef,
126        input_vector_col_idx: usize,
127        lookup: PlanRef,
128        lookup_vector: ExprImpl,
129        lookup_output_indices: Vec<usize>,
130        include_distance: bool,
131    ) -> Self {
132        let core = VectorSearchLookupJoinCore {
133            top_n,
134            distance_type,
135            input,
136            input_vector_col_idx,
137            lookup,
138            lookup_vector,
139            lookup_output_indices,
140            include_distance,
141        };
142        Self::with_core(core)
143    }
144
145    fn with_core(core: VectorSearchLookupJoinCore) -> Self {
146        let base = PlanBase::new_logical_with_core(&core);
147        Self { base, core }
148    }
149}
150
151impl_plan_tree_node_for_binary! { Logical, LogicalVectorSearchLookupJoin }
152
153impl PlanTreeNodeBinary<Logical> for LogicalVectorSearchLookupJoin {
154    fn left(&self) -> PlanRef {
155        self.core.input.clone()
156    }
157
158    fn right(&self) -> PlanRef {
159        self.core.lookup.clone()
160    }
161
162    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
163        let core = self.core.clone_with_input(left, right);
164        Self::with_core(core)
165    }
166}
167
168impl Distill for LogicalVectorSearchLookupJoin {
169    fn distill<'a>(&self) -> XmlNode<'a> {
170        let verbose = self.base.ctx().is_explain_verbose();
171        let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
172        vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
173        vec.push(("top_n", Pretty::debug(&self.core.top_n)));
174        vec.push((
175            "input_vector",
176            Pretty::debug(&self.core.input.schema()[self.core.input_vector_col_idx]),
177        ));
178
179        vec.push((
180            "lookup_vector",
181            Pretty::debug(&ExprDisplay {
182                expr: &self.core.lookup_vector,
183                input_schema: self.core.lookup.schema(),
184            }),
185        ));
186
187        if verbose {
188            vec.push((
189                "lookup_output_columns",
190                Pretty::Array(
191                    self.core
192                        .lookup_output_indices
193                        .iter()
194                        .map(|input_idx| {
195                            Pretty::debug(&self.core.lookup.schema().fields()[*input_idx])
196                        })
197                        .collect(),
198                ),
199            ));
200            vec.push((
201                "include_distance",
202                Pretty::debug(&self.core.include_distance),
203            ));
204        }
205
206        childless_record("LogicalVectorSearchLookupJoin", vec)
207    }
208}
209
210impl ColPrunable for LogicalVectorSearchLookupJoin {
211    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
212        let (project_exprs, mut required_cols) =
213            ensure_sorted_required_cols(required_cols, self.base.schema());
214        assert!(required_cols.is_sorted());
215        if let Some(last_col) = required_cols.last()
216            && *last_col == self.core.input.schema().len()
217        {
218            // pop the array_agg column, since we only prune base input
219            required_cols.pop();
220            let output_vector = required_cols.contains(&self.core.input_vector_col_idx);
221            if !output_vector {
222                // include vector column in the input
223                required_cols.push(self.core.input_vector_col_idx);
224            }
225
226            let new_input = self.core.input.prune_col(&required_cols, ctx);
227            let mut core = self
228                .core
229                .clone_with_input(new_input, self.core.lookup.clone());
230
231            core.input_vector_col_idx = ColIndexMapping::with_remaining_columns(
232                &required_cols,
233                self.core.input.schema().len(),
234            )
235            .map(self.core.input_vector_col_idx);
236            let vector_search = Self::with_core(core).into();
237            let input = if output_vector {
238                vector_search
239            } else {
240                // prune the vector column in the end of input, and include the array_agg column
241                LogicalProject::with_out_col_idx(
242                    vector_search,
243                    (0..required_cols.len() - 1).chain([required_cols.len()]),
244                )
245                .into()
246            };
247
248            LogicalProject::create(input, project_exprs)
249        } else {
250            // the array_agg column is pruned, no need to lookup
251            let input = self.core.input.prune_col(&required_cols, ctx);
252            LogicalProject::create(input, project_exprs)
253        }
254    }
255}
256
257impl ExprRewritable<Logical> for LogicalVectorSearchLookupJoin {}
258
259impl ExprVisitable for LogicalVectorSearchLookupJoin {}
260
261impl PredicatePushdown for LogicalVectorSearchLookupJoin {
262    fn predicate_pushdown(
263        &self,
264        predicate: Condition,
265        ctx: &mut PredicatePushdownContext,
266    ) -> PlanRef {
267        // TODO: push down to input when possible
268        let input = self
269            .core
270            .input
271            .predicate_pushdown(Condition::true_cond(), ctx);
272        let lookup = self
273            .core
274            .lookup
275            .predicate_pushdown(Condition::true_cond(), ctx);
276        let core = self.core.clone_with_input(input, lookup);
277        LogicalFilter::create(Self::with_core(core).into(), predicate)
278    }
279}
280
281impl ToStream for LogicalVectorSearchLookupJoin {
282    fn logical_rewrite_for_stream(
283        &self,
284        _ctx: &mut RewriteStreamContext,
285    ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
286        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
287    }
288
289    fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
290        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
291    }
292}
293
294impl LogicalVectorSearchLookupJoin {
295    pub(crate) fn as_index_lookup(&self) -> Option<(&Arc<VectorIndex>, Vec<usize>, Option<AsOf>)> {
296        if let Some(scan) = self.core.lookup.as_logical_scan()
297            && let Some((
298                index,
299                _covered_table_cols_idx,
300                non_covered_table_cols_idx,
301                primary_table_col_in_output,
302            )) = LogicalVectorSearch::resolve_vector_index_lookup(
303                scan,
304                &self.core.lookup_vector,
305                self.core.distance_type,
306                &self.core.lookup_output_indices,
307            )
308            && non_covered_table_cols_idx.is_empty()
309        {
310            let info_output_indices = primary_table_col_in_output
311                .iter()
312                .map(|(covered, idx_in_index_info_columns)| {
313                    assert!(*covered);
314                    *idx_in_index_info_columns
315                })
316                .collect();
317            Some((index, info_output_indices, scan.as_of()))
318        } else {
319            None
320        }
321    }
322}
323
324impl ToBatch for LogicalVectorSearchLookupJoin {
325    fn to_batch(&self) -> Result<BatchPlanRef> {
326        if let Some((index, info_output_indices, as_of)) = self.as_index_lookup() {
327            let hnsw_ef_search =
328                index.resolve_hnsw_ef_search(&self.core.ctx().session_ctx().config());
329            let core = VectorIndexLookupJoin {
330                input: self.core.input.to_batch()?,
331                top_n: self.core.top_n,
332                distance_type: self.core.distance_type,
333                index_name: index.index_table.name.clone(),
334                index_table_id: index.index_table.id,
335                info_column_desc: index.info_column_desc(),
336                info_output_indices,
337                include_distance: self.core.include_distance,
338                as_of,
339                vector_column_idx: self.core.input_vector_col_idx,
340                hnsw_ef_search,
341                ctx: self.core.ctx(),
342            };
343            return Ok(BatchVectorSearch::with_core(core).into());
344        }
345
346        bail!("no index found for BatchVectorSearchLookupJoin")
347    }
348}