risingwave_frontend/optimizer/plan_node/
logical_vector_search_lookup_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::array::VECTOR_DISTANCE_TYPE;
19use risingwave_common::bail;
20use risingwave_common::catalog::{Field, Schema};
21use risingwave_common::types::{DataType, StructType};
22use risingwave_common::util::column_index_mapping::ColIndexMapping;
23use risingwave_pb::common::PbDistanceType;
24use risingwave_sqlparser::ast::AsOf;
25
26use crate::OptimizerContextRef;
27use crate::catalog::index_catalog::VectorIndex;
28use crate::expr::{ExprDisplay, ExprImpl};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::{
31    GenericPlanNode, GenericPlanRef, VectorIndexLookupJoin, ensure_sorted_required_cols,
32};
33use crate::optimizer::plan_node::utils::{Distill, childless_record};
34use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
35use crate::optimizer::property::FunctionalDependencySet;
36use crate::utils::Condition;
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39struct VectorSearchLookupJoinCore {
40    top_n: u64,
41    distance_type: PbDistanceType,
42
43    input: PlanRef,
44    input_vector_col_idx: usize,
45    lookup: PlanRef,
46    lookup_vector: ExprImpl,
47
48    /// The indices of lookup that will be included in the output.
49    /// The index of distance column is `lookup_output_indices.len()`
50    lookup_output_indices: Vec<usize>,
51    include_distance: bool,
52}
53
54impl VectorSearchLookupJoinCore {
55    pub(crate) fn clone_with_input(&self, input: PlanRef, lookup: PlanRef) -> Self {
56        Self {
57            top_n: self.top_n,
58            distance_type: self.distance_type,
59            input,
60            input_vector_col_idx: self.input_vector_col_idx,
61            lookup,
62            lookup_vector: self.lookup_vector.clone(),
63            lookup_output_indices: self.lookup_output_indices.clone(),
64            include_distance: self.include_distance,
65        }
66    }
67
68    fn struct_type(&self) -> StructType {
69        StructType::row_expr_type(
70            self.lookup_output_indices
71                .iter()
72                .map(|i| {
73                    let field = &self.lookup.schema().fields[*i];
74                    field.data_type.clone()
75                })
76                .chain(self.include_distance.then_some(VECTOR_DISTANCE_TYPE)),
77        )
78    }
79}
80
81impl GenericPlanNode for VectorSearchLookupJoinCore {
82    fn functional_dependency(&self) -> FunctionalDependencySet {
83        // TODO: include dependency of array_agg column
84        FunctionalDependencySet::new(self.input.schema().len() + 1)
85    }
86
87    fn schema(&self) -> Schema {
88        let fields = self
89            .input
90            .schema()
91            .fields
92            .iter()
93            .cloned()
94            .chain([Field::new(
95                "array",
96                DataType::Struct(self.struct_type()).list(),
97            )])
98            .collect();
99
100        Schema { fields }
101    }
102
103    fn stream_key(&self) -> Option<Vec<usize>> {
104        self.input.stream_key().map(|key| key.to_vec())
105    }
106
107    fn ctx(&self) -> OptimizerContextRef {
108        self.input.ctx()
109    }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, Hash)]
113pub struct LogicalVectorSearchLookupJoin {
114    pub base: PlanBase<Logical>,
115    core: VectorSearchLookupJoinCore,
116}
117
118impl LogicalVectorSearchLookupJoin {
119    pub(crate) fn new(
120        top_n: u64,
121        distance_type: PbDistanceType,
122        input: PlanRef,
123        input_vector_col_idx: usize,
124        lookup: PlanRef,
125        lookup_vector: ExprImpl,
126        lookup_output_indices: Vec<usize>,
127        include_distance: bool,
128    ) -> Self {
129        let core = VectorSearchLookupJoinCore {
130            top_n,
131            distance_type,
132            input,
133            input_vector_col_idx,
134            lookup,
135            lookup_vector,
136            lookup_output_indices,
137            include_distance,
138        };
139        Self::with_core(core)
140    }
141
142    fn with_core(core: VectorSearchLookupJoinCore) -> Self {
143        let base = PlanBase::new_logical_with_core(&core);
144        Self { base, core }
145    }
146}
147
148impl_plan_tree_node_for_binary! { Logical, LogicalVectorSearchLookupJoin }
149
150impl PlanTreeNodeBinary<Logical> for LogicalVectorSearchLookupJoin {
151    fn left(&self) -> PlanRef {
152        self.core.input.clone()
153    }
154
155    fn right(&self) -> PlanRef {
156        self.core.lookup.clone()
157    }
158
159    fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
160        let core = self.core.clone_with_input(left, right);
161        Self::with_core(core)
162    }
163}
164
165impl Distill for LogicalVectorSearchLookupJoin {
166    fn distill<'a>(&self) -> XmlNode<'a> {
167        let verbose = self.base.ctx().is_explain_verbose();
168        let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
169        vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
170        vec.push(("top_n", Pretty::debug(&self.core.top_n)));
171        vec.push((
172            "input_vector",
173            Pretty::debug(&self.core.input.schema()[self.core.input_vector_col_idx]),
174        ));
175
176        vec.push((
177            "lookup_vector",
178            Pretty::debug(&ExprDisplay {
179                expr: &self.core.lookup_vector,
180                input_schema: self.core.lookup.schema(),
181            }),
182        ));
183
184        if verbose {
185            vec.push((
186                "lookup_output_columns",
187                Pretty::Array(
188                    self.core
189                        .lookup_output_indices
190                        .iter()
191                        .map(|input_idx| {
192                            Pretty::debug(&self.core.lookup.schema().fields()[*input_idx])
193                        })
194                        .collect(),
195                ),
196            ));
197            vec.push((
198                "include_distance",
199                Pretty::debug(&self.core.include_distance),
200            ));
201        }
202
203        childless_record("LogicalVectorSearchLookupJoin", vec)
204    }
205}
206
207impl ColPrunable for LogicalVectorSearchLookupJoin {
208    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
209        let (project_exprs, mut required_cols) =
210            ensure_sorted_required_cols(required_cols, self.base.schema());
211        assert!(required_cols.is_sorted());
212        if let Some(last_col) = required_cols.last()
213            && *last_col == self.core.input.schema().len()
214        {
215            // pop the array_agg column, since we only prune base input
216            required_cols.pop();
217            let output_vector = required_cols.contains(&self.core.input_vector_col_idx);
218            if !output_vector {
219                // include vector column in the input
220                required_cols.push(self.core.input_vector_col_idx);
221            }
222
223            let new_input = self.core.input.prune_col(&required_cols, ctx);
224            let mut core = self
225                .core
226                .clone_with_input(new_input, self.core.lookup.clone());
227
228            core.input_vector_col_idx = ColIndexMapping::with_remaining_columns(
229                &required_cols,
230                self.core.input.schema().len(),
231            )
232            .map(self.core.input_vector_col_idx);
233            let vector_search = Self::with_core(core).into();
234            let input = if output_vector {
235                vector_search
236            } else {
237                // prune the vector column in the end of input, and include the array_agg column
238                LogicalProject::with_out_col_idx(
239                    vector_search,
240                    (0..required_cols.len() - 1).chain([required_cols.len()]),
241                )
242                .into()
243            };
244
245            LogicalProject::create(input, project_exprs)
246        } else {
247            // the array_agg column is pruned, no need to lookup
248            let input = self.core.input.prune_col(&required_cols, ctx);
249            LogicalProject::create(input, project_exprs)
250        }
251    }
252}
253
254impl ExprRewritable<Logical> for LogicalVectorSearchLookupJoin {}
255
256impl ExprVisitable for LogicalVectorSearchLookupJoin {}
257
258impl PredicatePushdown for LogicalVectorSearchLookupJoin {
259    fn predicate_pushdown(
260        &self,
261        predicate: Condition,
262        ctx: &mut PredicatePushdownContext,
263    ) -> PlanRef {
264        // TODO: push down to input when possible
265        let input = self
266            .core
267            .input
268            .predicate_pushdown(Condition::true_cond(), ctx);
269        let lookup = self
270            .core
271            .lookup
272            .predicate_pushdown(Condition::true_cond(), ctx);
273        let core = self.core.clone_with_input(input, lookup);
274        LogicalFilter::create(Self::with_core(core).into(), predicate)
275    }
276}
277
278impl ToStream for LogicalVectorSearchLookupJoin {
279    fn logical_rewrite_for_stream(
280        &self,
281        ctx: &mut RewriteStreamContext,
282    ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
283        if !self
284            .core
285            .input
286            .logical_rewrite_for_stream(ctx)?
287            .1
288            .is_identity()
289        {
290            // TODO: support it
291            bail!(
292                "LogicalVectorSearchLookupJoin does not support input that can possibly be rewritten"
293            )
294        }
295        Ok((
296            self.clone().into(),
297            ColIndexMapping::identity(self.base.schema().len()),
298        ))
299    }
300
301    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
302        if let Some(core) = self.to_vector_index_lookup_join(|plan| plan.to_stream(ctx))? {
303            if !matches!(&core.as_of, Some(AsOf::ProcessTime)) {
304                bail!("streaming vector index lookup join must be proctime temporal join");
305            }
306            return Ok(StreamVectorIndexLookupJoin::new(core)?.into());
307        }
308        bail!("LogicalVectorSearchLookupJoin should use proper vector index in streaming job")
309    }
310}
311
312impl LogicalVectorSearchLookupJoin {
313    pub(crate) fn as_index_lookup(&self) -> Option<(&Arc<VectorIndex>, Vec<usize>, Option<AsOf>)> {
314        if let Some(scan) = self.core.lookup.as_logical_scan()
315            && let Some((
316                index,
317                _covered_table_cols_idx,
318                non_covered_table_cols_idx,
319                primary_table_col_in_output,
320            )) = LogicalVectorSearch::resolve_vector_index_lookup(
321                scan,
322                &self.core.lookup_vector,
323                self.core.distance_type,
324                &self.core.lookup_output_indices,
325            )
326            && non_covered_table_cols_idx.is_empty()
327        {
328            let info_output_indices = primary_table_col_in_output
329                .iter()
330                .map(|(covered, idx_in_index_info_columns)| {
331                    assert!(*covered);
332                    *idx_in_index_info_columns
333                })
334                .collect();
335            Some((index, info_output_indices, scan.as_of()))
336        } else {
337            None
338        }
339    }
340}
341
342impl LogicalVectorSearchLookupJoin {
343    fn to_vector_index_lookup_join<PlanRef>(
344        &self,
345        gen_input: impl FnOnce(&LogicalPlanRef) -> Result<PlanRef>,
346    ) -> Result<Option<VectorIndexLookupJoin<PlanRef>>> {
347        if let Some((index, info_output_indices, as_of)) = self.as_index_lookup() {
348            let hnsw_ef_search =
349                index.resolve_hnsw_ef_search(&self.core.ctx().session_ctx().config());
350            let core = VectorIndexLookupJoin {
351                input: gen_input(&self.core.input)?,
352                top_n: self.core.top_n,
353                distance_type: self.core.distance_type,
354                index_name: index.index_table.name.clone(),
355                index_table_id: index.index_table.id,
356                info_column_desc: index.info_column_desc(),
357                info_output_indices,
358                include_distance: self.core.include_distance,
359                as_of,
360                vector_column_idx: self.core.input_vector_col_idx,
361                hnsw_ef_search,
362                ctx: self.core.ctx(),
363            };
364            return Ok(Some(core));
365        }
366        Ok(None)
367    }
368}
369
370impl ToBatch for LogicalVectorSearchLookupJoin {
371    fn to_batch(&self) -> Result<BatchPlanRef> {
372        if let Some(core) = self.to_vector_index_lookup_join(|plan| plan.to_batch())? {
373            return Ok(BatchVectorSearch::with_core(core).into());
374        }
375
376        bail!("no index found for BatchVectorSearchLookupJoin")
377    }
378}