risingwave_frontend/optimizer/plan_node/logical_vector_search.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use itertools::Itertools;
18use pretty_xmlish::{Pretty, XmlNode};
19use risingwave_common::array::VECTOR_DISTANCE_TYPE;
20use risingwave_common::bail;
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::types::{DataType, ScalarImpl};
23use risingwave_common::util::column_index_mapping::ColIndexMapping;
24use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
25use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
26use risingwave_pb::common::PbDistanceType;
27use risingwave_pb::plan_common::JoinType;
28
29use crate::OptimizerContextRef;
30use crate::catalog::index_catalog::VectorIndex;
31use crate::error::ErrorCode;
32use crate::expr::{
33    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
34    TableFunction, TableFunctionType, collect_input_refs,
35};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::{
38    GenericPlanNode, GenericPlanRef, TopNLimit, VectorIndexLookupJoin, ensure_sorted_required_cols,
39};
40use crate::optimizer::plan_node::utils::{Distill, childless_record};
41use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
42use crate::optimizer::property::{FunctionalDependencySet, Order};
43use crate::optimizer::rule::IndexSelectionRule;
44use crate::utils::{ColIndexMappingRewriteExt, Condition};
45
/// Shared state of the vector search plan node.
///
/// The node ranks the rows of `input` by the distance between `left` and `right` (computed
/// under `distance_type`) in ascending order and keeps the first `top_n` of them; for the inner
/// product the value is negated before ranking (see `LogicalVectorSearch::to_top_n`, which spells
/// these semantics out as a projection plus a top-n). The output is the input columns selected by
/// `output_indices`, followed by a trailing `vector_distance` column.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VectorSearchCore {
    /// Maximum number of rows kept after ranking by distance.
    top_n: u64,
    /// Distance metric applied to `left` and `right`. `Unspecified` is not expected here
    /// (`to_top_n` treats it as unreachable).
    distance_type: PbDistanceType,
    /// First operand of the distance function.
    left: ExprImpl,
    /// Second operand of the distance function. For the index-lookup plan, exactly one of
    /// `left`/`right` must be a constant expression (the query vector).
    right: ExprImpl,
    /// Indices of the input columns forwarded to the output, in output order.
    /// The distance column follows them, at output index `output_indices.len()`.
    output_indices: Vec<usize>,
    /// Child plan producing the rows that are searched.
    input: PlanRef,
}
57
58impl VectorSearchCore {
59    pub(crate) fn clone_with_input(&self, input: PlanRef) -> Self {
60        Self {
61            top_n: self.top_n,
62            distance_type: self.distance_type,
63            left: self.left.clone(),
64            right: self.right.clone(),
65            output_indices: self.output_indices.clone(),
66            input,
67        }
68    }
69
70    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
71        self.left = r.rewrite_expr(self.left.clone());
72        self.right = r.rewrite_expr(self.right.clone());
73    }
74
75    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
76        v.visit_expr(&self.left);
77        v.visit_expr(&self.right);
78    }
79
80    pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
81        let mut mapping = vec![None; self.input.schema().len()];
82        for (output_idx, input_idx) in self.output_indices.iter().enumerate() {
83            mapping[*input_idx] = Some(output_idx);
84        }
85        ColIndexMapping::new(mapping, self.output_indices.len() + 1)
86    }
87}
88
89impl GenericPlanNode for VectorSearchCore {
90    fn functional_dependency(&self) -> FunctionalDependencySet {
91        self.i2o_mapping()
92            .rewrite_functional_dependency_set(self.input.functional_dependency().clone())
93    }
94
95    fn schema(&self) -> Schema {
96        let fields = self
97            .output_indices
98            .iter()
99            .map(|idx| self.input.schema()[*idx].clone())
100            .chain([Field::new("vector_distance", DataType::Float64)])
101            .collect();
102        Schema { fields }
103    }
104
105    fn stream_key(&self) -> Option<Vec<usize>> {
106        self.input.stream_key().and_then(|v| {
107            let i2o_mapping = self.i2o_mapping();
108            v.iter().map(|idx| i2o_mapping.try_map(*idx)).collect()
109        })
110    }
111
112    fn ctx(&self) -> OptimizerContextRef {
113        self.input.ctx()
114    }
115}
116
/// Logical plan node performing a top-n vector distance search over its input.
///
/// This node is batch-only: `ToStream` always fails for it. `ToBatch` lowers it either to a
/// vector-index lookup or to an exact projection + top-n over the input.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalVectorSearch {
    /// Plan base derived from `core` (built in `with_core`).
    pub base: PlanBase<Logical>,
    /// Operator state: operands, metric, limit, forwarded columns and input.
    core: VectorSearchCore,
}
122
123impl LogicalVectorSearch {
124    pub(crate) fn new(
125        top_n: u64,
126        distance_type: PbDistanceType,
127        left: ExprImpl,
128        right: ExprImpl,
129        input: PlanRef,
130    ) -> Self {
131        let output_indices = (0..input.schema().len()).collect();
132        let core = VectorSearchCore {
133            top_n,
134            distance_type,
135            left,
136            right,
137            output_indices,
138            input,
139        };
140        Self::with_core(core)
141    }
142
143    fn with_core(core: VectorSearchCore) -> Self {
144        let base = PlanBase::new_logical_with_core(&core);
145        Self { base, core }
146    }
147}
148
149impl_plan_tree_node_for_unary! { Logical, LogicalVectorSearch }
150
151impl PlanTreeNodeUnary<Logical> for LogicalVectorSearch {
152    fn input(&self) -> PlanRef {
153        self.core.input.clone()
154    }
155
156    fn clone_with_input(&self, input: PlanRef) -> Self {
157        let core = self.core.clone_with_input(input);
158        Self::with_core(core)
159    }
160}
161
162impl Distill for LogicalVectorSearch {
163    fn distill<'a>(&self) -> XmlNode<'a> {
164        let verbose = self.base.ctx().is_explain_verbose();
165        let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
166        vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
167        vec.push(("top_n", Pretty::debug(&self.core.top_n)));
168        vec.push(("left", Pretty::debug(&self.core.left)));
169        vec.push(("right", Pretty::debug(&self.core.right)));
170
171        if verbose {
172            vec.push((
173                "output_columns",
174                Pretty::Array(
175                    self.core
176                        .output_indices
177                        .iter()
178                        .map(|input_idx| {
179                            Pretty::debug(&self.core.input.schema().fields()[*input_idx])
180                        })
181                        .collect(),
182                ),
183            ));
184        }
185
186        childless_record("LogicalVectorSearch", vec)
187    }
188}
189
190impl ColPrunable for LogicalVectorSearch {
191    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
192        let (project_exprs, required_cols) =
193            ensure_sorted_required_cols(required_cols, self.base.schema());
194        assert!(required_cols.is_sorted());
195        let input_schema = self.core.input.schema();
196        let mut required_input_idx_bitset =
197            collect_input_refs(input_schema.len(), [&self.core.left, &self.core.right]);
198        let mut non_distance_required_input_idx = Vec::new();
199        let require_distance_col = required_cols
200            .last()
201            .map(|last_col_idx| *last_col_idx == self.core.output_indices.len())
202            .unwrap_or(false);
203        let non_distance_iter_end_idx = if require_distance_col {
204            required_cols.len() - 1
205        } else {
206            required_cols.len()
207        };
208        for &required_col_idx in &required_cols[0..non_distance_iter_end_idx] {
209            let required_input_idx = self.core.output_indices[required_col_idx];
210            non_distance_required_input_idx.push(required_input_idx);
211            required_input_idx_bitset.set(required_col_idx, true);
212        }
213        let input_required_idx = required_input_idx_bitset.ones().collect_vec();
214
215        let new_input = self.input().prune_col(&input_required_idx, ctx);
216        // mapping from idx of original input to new input
217        let mut mapping = ColIndexMapping::with_remaining_columns(
218            &input_required_idx,
219            self.input().schema().len(),
220        );
221
222        let vector_search = {
223            let mut new_core = self.core.clone_with_input(new_input);
224            new_core.left = mapping.rewrite_expr(new_core.left);
225            new_core.right = mapping.rewrite_expr(new_core.right);
226            new_core.output_indices = non_distance_required_input_idx
227                .iter()
228                .map(|input_idx| mapping.map(*input_idx))
229                .collect();
230            Self::with_core(new_core)
231        };
232        LogicalProject::create(vector_search.into(), project_exprs)
233    }
234}
235
236impl ExprRewritable<Logical> for LogicalVectorSearch {
237    fn has_rewritable_expr(&self) -> bool {
238        true
239    }
240
241    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
242        let mut core = self.core.clone();
243        core.rewrite_exprs(r);
244        Self::with_core(core).into()
245    }
246}
247
248impl ExprVisitable for LogicalVectorSearch {
249    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
250        self.core.visit_exprs(v);
251    }
252}
253
254impl PredicatePushdown for LogicalVectorSearch {
255    fn predicate_pushdown(
256        &self,
257        predicate: Condition,
258        ctx: &mut PredicatePushdownContext,
259    ) -> PlanRef {
260        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
261    }
262}
263
264impl ToStream for LogicalVectorSearch {
265    fn logical_rewrite_for_stream(
266        &self,
267        _ctx: &mut RewriteStreamContext,
268    ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
269        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
270    }
271
272    fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
273        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
274    }
275}
276
277impl LogicalVectorSearch {
278    pub fn to_top_n(
279        input: PlanRef,
280        left: ExprImpl,
281        right: ExprImpl,
282        distance_type: PbDistanceType,
283        top_n: u64,
284        output_indices: &[usize],
285    ) -> LogicalTopN {
286        let (neg, expr_type) = match distance_type {
287            PbDistanceType::Unspecified => {
288                unreachable!()
289            }
290            PbDistanceType::L1 => (false, ExprType::L1Distance),
291            PbDistanceType::L2Sqr => (false, ExprType::L2Distance),
292            PbDistanceType::Cosine => (false, ExprType::CosineDistance),
293            PbDistanceType::InnerProduct => (true, ExprType::InnerProduct),
294        };
295        let mut expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
296            expr_type,
297            vec![left, right],
298            VECTOR_DISTANCE_TYPE,
299        )));
300        if neg {
301            expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
302                ExprType::Neg,
303                vec![expr],
304                VECTOR_DISTANCE_TYPE,
305            )));
306        }
307        let exprs = generic::Project::out_col_idx_exprs(&input, output_indices.iter().copied())
308            .chain([expr])
309            .collect();
310
311        let input = LogicalProject::new(input, exprs).into();
312        let top_n = generic::TopN::without_group(
313            input,
314            TopNLimit::Simple(top_n),
315            0,
316            Order::new(vec![ColumnOrder::new(
317                output_indices.len(),
318                OrderType::ascending(),
319            )]),
320        );
321        top_n.into()
322    }
323
324    fn as_vector_table_scan(&self) -> Option<(&LogicalScan, ExprImpl, &ExprImpl)> {
325        let scan = self.core.input.as_logical_scan()?;
326        let left_const = (self.core.left.only_literal_and_func(), &self.core.left);
327        let right_const = (self.core.right.only_literal_and_func(), &self.core.right);
328        let (vector_column_expr, vector_expr) = match (left_const, right_const) {
329            ((true, _), (true, _)) => {
330                return None;
331            }
332            ((_, vector_column_expr), (true, vector_expr))
333            | ((true, vector_expr), (_, vector_column_expr)) => (vector_column_expr, vector_expr),
334            _ => return None,
335        };
336        Some((scan, vector_expr.clone(), vector_column_expr))
337    }
338
339    fn is_matched_vector_column_expr(
340        index_expr: &ExprImpl,
341        column_expr: &ExprImpl,
342        scan_output_col_idx: &[usize],
343    ) -> bool {
344        match (index_expr, column_expr) {
345            (ExprImpl::Literal(l1), ExprImpl::Literal(l2)) => l1 == l2,
346            (ExprImpl::InputRef(i1), ExprImpl::InputRef(i2)) => {
347                i1.index == scan_output_col_idx[i2.index]
348            }
349            (ExprImpl::FunctionCall(f1), ExprImpl::FunctionCall(f2)) => {
350                f1.func_type() == f2.func_type()
351                    && f1.return_type() == f2.return_type()
352                    && f1.inputs().len() == f2.inputs().len()
353                    && f1.inputs().iter().zip_eq_fast(f2.inputs()).all(|(e1, e2)| {
354                        Self::is_matched_vector_column_expr(e1, e2, scan_output_col_idx)
355                    })
356            }
357            _ => false,
358        }
359    }
360}
361
362impl LogicalVectorSearch {
363    #[expect(clippy::type_complexity)]
364    pub fn resolve_vector_index_lookup<'a>(
365        scan: &'a LogicalScan,
366        vector_column_expr: &ExprImpl,
367        distance_type: PbDistanceType,
368        output_indices: &[usize],
369    ) -> Option<(
370        &'a Arc<VectorIndex>,
371        Vec<usize>,
372        Vec<usize>,
373        Vec<(bool, usize)>,
374    )> {
375        if !scan.predicate().always_true() {
376            return None;
377        }
378        if !scan.vector_indexes().is_empty() {
379            let primary_table_cols_idx = output_indices
380                .iter()
381                .map(|input_idx| scan.output_col_idx()[*input_idx])
382                .collect_vec();
383            for index in scan.vector_indexes() {
384                if !Self::is_matched_vector_column_expr(
385                    &index.vector_expr,
386                    vector_column_expr,
387                    scan.output_col_idx(),
388                ) {
389                    continue;
390                }
391                if index.vector_index_info.distance_type() != distance_type {
392                    continue;
393                }
394
395                let mut covered_table_cols_idx = Vec::new();
396                let mut non_covered_table_cols_idx = Vec::new();
397                let mut primary_table_col_in_output =
398                    Vec::with_capacity(primary_table_cols_idx.len());
399                for table_col_idx in &primary_table_cols_idx {
400                    if let Some(covered_info_column_idx) = index
401                        .primary_to_included_info_column_mapping
402                        .get(table_col_idx)
403                    {
404                        covered_table_cols_idx.push(*table_col_idx);
405                        primary_table_col_in_output.push((true, *covered_info_column_idx));
406                    } else {
407                        primary_table_col_in_output.push((false, non_covered_table_cols_idx.len()));
408                        non_covered_table_cols_idx.push(*table_col_idx);
409                    }
410                }
411                return Some((
412                    index,
413                    covered_table_cols_idx,
414                    non_covered_table_cols_idx,
415                    primary_table_col_in_output,
416                ));
417            }
418        }
419        None
420    }
421}
422
impl ToBatch for LogicalVectorSearch {
    /// Converts this vector search into a batch plan.
    ///
    /// Index path, taken when all of the following hold:
    /// - the input is a `LogicalScan` and exactly one operand is constant (see
    ///   `as_vector_table_scan`);
    /// - the session enables index selection;
    /// - `resolve_vector_index_lookup` finds a matching vector index.
    ///
    /// The search then runs against the index. Requested columns the index does not include are
    /// fetched from the primary table through a lookup join on the primary key.
    ///
    /// Fallback: otherwise the search is planned exactly via `to_top_n` and lowered to batch.
    fn to_batch(&self) -> Result<BatchPlanRef> {
        if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan()
            && self
                .core
                .ctx()
                .session_ctx()
                .config()
                .enable_index_selection()
            && let Some((
                index,
                covered_table_cols_idx,
                non_covered_table_cols_idx,
                primary_table_col_in_output,
            )) = Self::resolve_vector_index_lookup(
                scan,
                vector_column_expr,
                self.core.distance_type,
                &self.core.output_indices,
            )
        {
            {
                // One-row `Values` holding the constant query vector; it is the input that drives
                // the index search below.
                let vector_data_type = vector_expr.return_type();
                let literal_vector_input = BatchValues::new(LogicalValues::new(
                    vec![vec![vector_expr]],
                    Schema::from_iter([Field::new("query_vector", vector_data_type)]),
                    self.core.ctx(),
                ))
                .into();
                let hnsw_ef_search =
                    index.resolve_hnsw_ef_search(&self.core.ctx().session_ctx().config());
                let info_column_desc = index.info_column_desc();
                // Search the index with the query vector (column 0 of the input), reading every
                // info column of the index and also requesting the distance.
                let core = VectorIndexLookupJoin {
                    input: literal_vector_input,
                    top_n: self.core.top_n,
                    distance_type: self.core.distance_type,
                    index_name: index.index_table.name.clone(),
                    index_table_id: index.index_table.id,
                    info_output_indices: (0..info_column_desc.len()).collect(),
                    info_column_desc,
                    include_distance: true,
                    as_of: scan.as_of(),
                    vector_column_idx: 0,
                    hnsw_ef_search,
                    ctx: self.core.ctx(),
                };
                // Flatten the search result into plain columns:
                // 1. column 1 of `BatchVectorSearch`'s output is unnested into one row per
                //    element (column 0 presumably being the pass-through query vector);
                // 2. the unnested column (column 1 of the `ProjectSet` output) must be a struct,
                //    otherwise the `panic!` below fires;
                // 3. a projection extracts each struct field into its own column.
                // NOTE(review): the resulting columns are presumably the index's info columns
                // followed by the distance; both branches below locate the distance that way.
                let vector_search: BatchPlanRef = {
                    let vector_search: BatchPlanRef = BatchVectorSearch::with_core(core).into();
                    let unnested_array: BatchPlanRef = BatchProjectSet::new(generic::ProjectSet {
                        select_list: vec![ExprImpl::TableFunction(
                            TableFunction::new(
                                TableFunctionType::Unnest,
                                vec![ExprImpl::InputRef(
                                    InputRef::new(1, vector_search.schema()[1].data_type()).into(),
                                )],
                            )?
                            .into(),
                        )],
                        input: vector_search,
                    })
                    .into();
                    let DataType::Struct(struct_type) = &unnested_array.schema()[1].data_type
                    else {
                        panic!("{:?}", unnested_array.schema()[1].data_type);
                    };
                    // One `Field(struct, idx)` expression per struct field, in field order.
                    let unnest_struct = BatchProject::new(generic::Project::new(
                        struct_type
                            .types()
                            .enumerate()
                            .map(|(idx, data_type)| {
                                ExprImpl::FunctionCall(
                                    FunctionCall::new_unchecked(
                                        ExprType::Field,
                                        vec![
                                            ExprImpl::InputRef(
                                                InputRef::new(
                                                    1,
                                                    DataType::Struct(struct_type.clone()),
                                                )
                                                .into(),
                                            ),
                                            ExprImpl::Literal(
                                                Literal::new(
                                                    Some(ScalarImpl::Int32(idx as _)),
                                                    DataType::Int32,
                                                )
                                                .into(),
                                            ),
                                        ],
                                        data_type.clone(),
                                    )
                                    .into(),
                                )
                            })
                            .collect(),
                        unnested_array,
                    ));
                    unnest_struct.into()
                };
                // Info-column position of each requested table column that the index includes.
                let covered_output_col_idx = covered_table_cols_idx.iter().map(|table_col_idx| {
                    index.primary_to_included_info_column_mapping[table_col_idx]
                });
                return Ok(if non_covered_table_cols_idx.is_empty() {
                    // Every requested column is stored in the index: project the covered info
                    // columns plus the column at `included_info_columns.len()` (the distance).
                    BatchProject::new(generic::Project::with_out_col_idx(
                        vector_search,
                        covered_output_col_idx.chain([index.included_info_columns.len()]),
                    ))
                    .into()
                } else {
                    // Some requested columns are missing from the index: scan the primary table
                    // for them, followed by its primary-key columns, and join back on the pk.
                    let mut primary_table_cols_idx = Vec::with_capacity(
                        non_covered_table_cols_idx.len() + scan.table().pk().len(),
                    );
                    primary_table_cols_idx.extend(
                        non_covered_table_cols_idx
                            .iter()
                            .cloned()
                            .chain(scan.table().pk().iter().map(|order| order.column_index)),
                    );
                    let table_scan = generic::TableScan::new(
                        primary_table_cols_idx,
                        scan.table().clone(),
                        vec![],
                        vec![],
                        self.core.input.ctx(),
                        Condition::true_cond(),
                        scan.as_of(),
                    );
                    let logical_scan = LogicalScan::from(table_scan);
                    let batch_scan = logical_scan.to_batch()?;
                    let vector_search_schema = vector_search.schema();
                    let vector_search_schema_len = vector_search_schema.len();
                    // One null-safe equality per pk column. The left side is that pk column's
                    // position among the index info columns. The right side is the same column in
                    // `batch_scan`, located after the non-covered columns and offset by the
                    // left side's width, because the join condition indexes the concatenated
                    // schema.
                    let on_condition = Condition {
                        conjunctions: index
                            .primary_key_idx_in_info_columns
                            .iter()
                            .zip_eq_debug(0..scan.table().pk().len())
                            .map(|(pk_idx_in_info_columns, pk_idx)| {
                                let batch_scan_pk_idx = vector_search_schema.len()
                                    + non_covered_table_cols_idx.len()
                                    + pk_idx;
                                IndexSelectionRule::create_null_safe_equal_expr(
                                    *pk_idx_in_info_columns,
                                    vector_search_schema[*pk_idx_in_info_columns].data_type(),
                                    batch_scan_pk_idx,
                                    batch_scan.schema()[non_covered_table_cols_idx.len() + pk_idx]
                                        .data_type(),
                                )
                            })
                            .collect(),
                    };
                    let eq_predicate = EqJoinPredicate::create(
                        vector_search_schema.len(),
                        batch_scan.schema().len(),
                        on_condition.clone(),
                    );
                    // Join output, one column per requested column: a covered column is taken
                    // from the left (index) side, a non-covered one from the right (table scan)
                    // side. The last left column is appended at the end.
                    let join = generic::Join::new(
                        vector_search,
                        batch_scan,
                        on_condition,
                        JoinType::Inner,
                        primary_table_col_in_output
                            .iter()
                            .map(|(covered, idx)| {
                                if *covered {
                                    *idx
                                } else {
                                    *idx + vector_search_schema_len
                                }
                            })
                            // chain with distance column
                            .chain([vector_search_schema_len - 1])
                            .collect(),
                    );
                    // `gen_batch_lookup_join` may decline (`None`); that is reported as an
                    // internal error rather than falling back to the exact plan.
                    let lookup_join = LogicalJoin::gen_batch_lookup_join(
                        &logical_scan,
                        eq_predicate,
                        join,
                        false,
                    )?
                    .ok_or_else(|| {
                        ErrorCode::InternalError(
                            "failed to convert to batch lookup join".to_owned(),
                        )
                    })?;
                    lookup_join.into()
                });
            }
        }
        // No usable index: plan the search exactly as projection + top-n.
        Self::to_top_n(
            self.core.input.clone(),
            self.core.left.clone(),
            self.core.right.clone(),
            self.core.distance_type,
            self.core.top_n,
            &self.core.output_indices,
        )
        .to_batch()
    }
}