// risingwave_frontend/optimizer/plan_node/logical_vector_search.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::array::VECTOR_DISTANCE_TYPE;
18use risingwave_common::bail;
19use risingwave_common::catalog::{Field, Schema};
20use risingwave_common::types::{DataType, ScalarImpl};
21use risingwave_common::util::column_index_mapping::ColIndexMapping;
22use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
23use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
24use risingwave_pb::catalog::vector_index_info;
25use risingwave_pb::common::PbDistanceType;
26use risingwave_pb::plan_common::JoinType;
27
28use crate::OptimizerContextRef;
29use crate::error::ErrorCode;
30use crate::expr::{
31    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32    TableFunction, TableFunctionType, collect_input_refs,
33};
34use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
35use crate::optimizer::plan_node::generic::{
36    GenericPlanNode, GenericPlanRef, TopNLimit, VectorIndexLookupJoin, ensure_sorted_required_cols,
37};
38use crate::optimizer::plan_node::utils::{Distill, childless_record};
39use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
40use crate::optimizer::property::{FunctionalDependencySet, Order};
41use crate::optimizer::rule::IndexSelectionRule;
42use crate::utils::{ColIndexMappingRewriteExt, Condition};
43
/// Plan-node-agnostic part of a vector search: keeps the nearest `top_n` rows of `input`
/// according to the `distance_type` distance between `left` and `right`, and outputs the
/// selected input columns followed by a `vector_distance` column.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VectorSearchCore {
    /// Maximum number of rows produced.
    top_n: u64,
    /// Distance metric. `Unspecified` is not supported by the fallback top-n lowering.
    distance_type: PbDistanceType,
    /// One operand of the distance function. For index selection, exactly one of `left` /
    /// `right` must consist only of literals and function calls (the query vector).
    left: ExprImpl,
    /// The other operand of the distance function.
    right: ExprImpl,
    /// The indices of input that will be included in the output.
    /// The index of distance column is `output_indices.len()`
    output_indices: Vec<usize>,
    /// The plan whose rows are searched.
    input: PlanRef,
}
55
56impl VectorSearchCore {
57    pub(crate) fn clone_with_input(&self, input: PlanRef) -> Self {
58        Self {
59            top_n: self.top_n,
60            distance_type: self.distance_type,
61            left: self.left.clone(),
62            right: self.right.clone(),
63            output_indices: self.output_indices.clone(),
64            input,
65        }
66    }
67
68    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
69        self.left = r.rewrite_expr(self.left.clone());
70        self.right = r.rewrite_expr(self.right.clone());
71    }
72
73    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
74        v.visit_expr(&self.left);
75        v.visit_expr(&self.right);
76    }
77
78    pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
79        let mut mapping = vec![None; self.input.schema().len()];
80        for (output_idx, input_idx) in self.output_indices.iter().enumerate() {
81            mapping[*input_idx] = Some(output_idx);
82        }
83        ColIndexMapping::new(mapping, self.output_indices.len() + 1)
84    }
85}
86
87impl GenericPlanNode for VectorSearchCore {
88    fn functional_dependency(&self) -> FunctionalDependencySet {
89        self.i2o_mapping()
90            .rewrite_functional_dependency_set(self.input.functional_dependency().clone())
91    }
92
93    fn schema(&self) -> Schema {
94        let fields = self
95            .output_indices
96            .iter()
97            .map(|idx| self.input.schema()[*idx].clone())
98            .chain([Field::new("vector_distance", DataType::Float64)])
99            .collect();
100        Schema { fields }
101    }
102
103    fn stream_key(&self) -> Option<Vec<usize>> {
104        self.input.stream_key().and_then(|v| {
105            let i2o_mapping = self.i2o_mapping();
106            v.iter().map(|idx| i2o_mapping.try_map(*idx)).collect()
107        })
108    }
109
110    fn ctx(&self) -> OptimizerContextRef {
111        self.input.ctx()
112    }
113}
114
/// Logical plan node for a top-n nearest-neighbor vector search.
///
/// It is batch-only (`ToStream` bails). In batch it is lowered either to a vector-index lookup,
/// when a matching index exists and index selection is enabled, or to a project + top-n plan.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalVectorSearch {
    pub base: PlanBase<Logical>,
    core: VectorSearchCore,
}
120
121impl LogicalVectorSearch {
122    pub(crate) fn new(
123        top_n: u64,
124        distance_type: PbDistanceType,
125        left: ExprImpl,
126        right: ExprImpl,
127        output_indices: Vec<usize>,
128        input: PlanRef,
129    ) -> Self {
130        let core = VectorSearchCore {
131            top_n,
132            distance_type,
133            left,
134            right,
135            output_indices,
136            input,
137        };
138        Self::with_core(core)
139    }
140
141    fn with_core(core: VectorSearchCore) -> Self {
142        let base = PlanBase::new_logical_with_core(&core);
143        Self { base, core }
144    }
145
146    pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
147        self.core.i2o_mapping()
148    }
149}
150
// Generates the generic plan-tree-node impl for this single-input logical node (presumably
// built on its `PlanTreeNodeUnary` impl; see the macro definition).
impl_plan_tree_node_for_unary! { Logical, LogicalVectorSearch }
152
153impl PlanTreeNodeUnary<Logical> for LogicalVectorSearch {
154    fn input(&self) -> PlanRef {
155        self.core.input.clone()
156    }
157
158    fn clone_with_input(&self, input: PlanRef) -> Self {
159        let core = self.core.clone_with_input(input);
160        Self::with_core(core)
161    }
162}
163
164impl Distill for LogicalVectorSearch {
165    fn distill<'a>(&self) -> XmlNode<'a> {
166        let verbose = self.base.ctx().is_explain_verbose();
167        let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
168        vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
169        vec.push(("top_n", Pretty::debug(&self.core.top_n)));
170        vec.push(("left", Pretty::debug(&self.core.left)));
171        vec.push(("right", Pretty::debug(&self.core.right)));
172
173        if verbose {
174            vec.push((
175                "output_columns",
176                Pretty::Array(
177                    self.core
178                        .output_indices
179                        .iter()
180                        .map(|input_idx| {
181                            Pretty::debug(&self.core.input.schema().fields()[*input_idx])
182                        })
183                        .collect(),
184                ),
185            ));
186        }
187
188        childless_record("LogicalVectorSearch", vec)
189    }
190}
191
192impl ColPrunable for LogicalVectorSearch {
193    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
194        let (project_exprs, required_cols) =
195            ensure_sorted_required_cols(required_cols, self.base.schema());
196        assert!(required_cols.is_sorted());
197        let input_schema = self.core.input.schema();
198        let mut required_input_idx_bitset =
199            collect_input_refs(input_schema.len(), [&self.core.left, &self.core.right]);
200        let mut non_distance_required_input_idx = Vec::new();
201        let require_distance_col = required_cols
202            .last()
203            .map(|last_col_idx| *last_col_idx == self.core.output_indices.len())
204            .unwrap_or(false);
205        let non_distance_iter_end_idx = if require_distance_col {
206            required_cols.len() - 1
207        } else {
208            required_cols.len()
209        };
210        for &required_col_idx in &required_cols[0..non_distance_iter_end_idx] {
211            let required_input_idx = self.core.output_indices[required_col_idx];
212            non_distance_required_input_idx.push(required_input_idx);
213            required_input_idx_bitset.set(required_col_idx, true);
214        }
215        let input_required_idx = required_input_idx_bitset.ones().collect_vec();
216
217        let new_input = self.input().prune_col(&input_required_idx, ctx);
218        // mapping from idx of original input to new input
219        let mut mapping = ColIndexMapping::with_remaining_columns(
220            &input_required_idx,
221            self.input().schema().len(),
222        );
223
224        let vector_search = {
225            let mut new_core = self.core.clone_with_input(new_input);
226            new_core.left = mapping.rewrite_expr(new_core.left);
227            new_core.right = mapping.rewrite_expr(new_core.right);
228            new_core.output_indices = non_distance_required_input_idx
229                .iter()
230                .map(|input_idx| mapping.map(*input_idx))
231                .collect();
232            Self::with_core(new_core)
233        };
234        LogicalProject::create(vector_search.into(), project_exprs)
235    }
236}
237
238impl ExprRewritable<Logical> for LogicalVectorSearch {
239    fn has_rewritable_expr(&self) -> bool {
240        true
241    }
242
243    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
244        let mut core = self.core.clone();
245        core.rewrite_exprs(r);
246        Self::with_core(core).into()
247    }
248}
249
250impl ExprVisitable for LogicalVectorSearch {
251    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
252        self.core.visit_exprs(v);
253    }
254}
255
256impl PredicatePushdown for LogicalVectorSearch {
257    fn predicate_pushdown(
258        &self,
259        predicate: Condition,
260        ctx: &mut PredicatePushdownContext,
261    ) -> PlanRef {
262        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
263    }
264}
265
266impl ToStream for LogicalVectorSearch {
267    fn logical_rewrite_for_stream(
268        &self,
269        _ctx: &mut RewriteStreamContext,
270    ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
271        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
272    }
273
274    fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
275        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
276    }
277}
278
279impl LogicalVectorSearch {
280    fn to_top_n(&self) -> LogicalTopN {
281        let (neg, expr_type) = match self.core.distance_type {
282            PbDistanceType::Unspecified => {
283                unreachable!()
284            }
285            PbDistanceType::L1 => (false, ExprType::L1Distance),
286            PbDistanceType::L2Sqr => (false, ExprType::L2Distance),
287            PbDistanceType::Cosine => (false, ExprType::CosineDistance),
288            PbDistanceType::InnerProduct => (true, ExprType::InnerProduct),
289        };
290        let mut expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
291            expr_type,
292            vec![self.core.left.clone(), self.core.right.clone()],
293            VECTOR_DISTANCE_TYPE,
294        )));
295        if neg {
296            expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
297                ExprType::Neg,
298                vec![expr],
299                VECTOR_DISTANCE_TYPE,
300            )));
301        }
302        let exprs = generic::Project::out_col_idx_exprs(
303            &self.core.input,
304            self.core.output_indices.iter().copied(),
305        )
306        .chain([expr])
307        .collect();
308
309        let input = LogicalProject::new(self.input(), exprs).into();
310        let top_n = generic::TopN::without_group(
311            input,
312            TopNLimit::Simple(self.core.top_n),
313            0,
314            Order::new(vec![ColumnOrder::new(
315                self.core.output_indices.len(),
316                OrderType::ascending(),
317            )]),
318        );
319        top_n.into()
320    }
321
322    fn as_vector_table_scan(&self) -> Option<(&LogicalScan, ExprImpl, &ExprImpl)> {
323        let scan = self.core.input.as_logical_scan()?;
324        if !scan.predicate().always_true() {
325            return None;
326        }
327        let left_const = (self.core.left.only_literal_and_func(), &self.core.left);
328        let right_const = (self.core.right.only_literal_and_func(), &self.core.right);
329        let (vector_column_expr, vector_expr) = match (left_const, right_const) {
330            ((true, _), (true, _)) => {
331                return None;
332            }
333            ((_, vector_column_expr), (true, vector_expr))
334            | ((true, vector_expr), (_, vector_column_expr)) => (vector_column_expr, vector_expr),
335            _ => return None,
336        };
337        Some((scan, vector_expr.clone(), vector_column_expr))
338    }
339
340    fn is_matched_vector_column_expr(
341        index_expr: &ExprImpl,
342        column_expr: &ExprImpl,
343        scan_output_col_idx: &[usize],
344    ) -> bool {
345        match (index_expr, column_expr) {
346            (ExprImpl::Literal(l1), ExprImpl::Literal(l2)) => l1 == l2,
347            (ExprImpl::InputRef(i1), ExprImpl::InputRef(i2)) => {
348                i1.index == scan_output_col_idx[i2.index]
349            }
350            (ExprImpl::FunctionCall(f1), ExprImpl::FunctionCall(f2)) => {
351                f1.func_type() == f2.func_type()
352                    && f1.return_type() == f2.return_type()
353                    && f1.inputs().len() == f2.inputs().len()
354                    && f1.inputs().iter().zip_eq_fast(f2.inputs()).all(|(e1, e2)| {
355                        Self::is_matched_vector_column_expr(e1, e2, scan_output_col_idx)
356                    })
357            }
358            _ => false,
359        }
360    }
361}
362
impl ToBatch for LogicalVectorSearch {
    /// Lowers this node to a batch plan.
    ///
    /// An index is used only when all of the following hold:
    /// - the input is an unfiltered table scan that has vector indexes;
    /// - index selection is enabled in the session config;
    /// - exactly one distance operand is constant.
    ///
    /// In that case the first vector index built on the same column expression with the same
    /// distance type is used. The constant query vector is looked up in that index. Output
    /// columns not stored in the index's included info columns are fetched from the primary
    /// table through a lookup join on the primary key.
    ///
    /// If no index applies, this falls back to the project + top-n plan from `to_top_n`.
    fn to_batch(&self) -> crate::error::Result<BatchPlanRef> {
        if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan()
            && !scan.vector_indexes().is_empty()
            && self
                .core
                .ctx()
                .session_ctx()
                .config()
                .enable_index_selection()
        {
            // The first index passing both checks below is used; the rest of the loop body
            // always returns.
            for index in scan.vector_indexes() {
                // The index must be built on the same expression as the non-constant operand...
                if !Self::is_matched_vector_column_expr(
                    &index.vector_expr,
                    vector_column_expr,
                    scan.output_col_idx(),
                ) {
                    continue;
                }
                // ...and must use the same distance metric as this search.
                if index.vector_index_info.distance_type() != self.core.distance_type {
                    continue;
                }

                // Primary-table column index of each forwarded (non-distance) output column.
                let primary_table_cols_idx = self
                    .core
                    .output_indices
                    .iter()
                    .map(|input_idx| scan.output_col_idx()[*input_idx])
                    .collect_vec();
                // Split the output columns into two groups: those stored in the index as
                // included info columns ("covered"), and those that must be read from the
                // primary table ("non-covered").
                // For output column `i`, `primary_table_col_in_output[i]` is:
                // - `(true, info_column_idx)` if the column is covered;
                // - `(false, position in non_covered_table_cols_idx)` otherwise.
                let mut covered_table_cols_idx = Vec::new();
                let mut non_covered_table_cols_idx = Vec::new();
                let mut primary_table_col_in_output =
                    Vec::with_capacity(primary_table_cols_idx.len());
                for table_col_idx in &primary_table_cols_idx {
                    if let Some(covered_info_column_idx) = index
                        .primary_to_included_info_column_mapping
                        .get(table_col_idx)
                    {
                        covered_table_cols_idx.push(*table_col_idx);
                        primary_table_col_in_output.push((true, *covered_info_column_idx));
                    } else {
                        primary_table_col_in_output.push((false, non_covered_table_cols_idx.len()));
                        non_covered_table_cols_idx.push(*table_col_idx);
                    }
                }
                // A single-row input holding the constant query vector.
                let vector_data_type = vector_expr.return_type();
                let literal_vector_input = BatchValues::new(LogicalValues::new(
                    vec![vec![vector_expr]],
                    Schema::from_iter([Field::new("query_vector", vector_data_type)]),
                    self.core.ctx(),
                ))
                .into();
                // `ef_search` only applies to HNSW indexes and is read from the session config.
                // NOTE(review): `unwrap` assumes every vector index has its config set.
                let hnsw_ef_search = match index.vector_index_info.config.as_ref().unwrap() {
                    vector_index_info::Config::Flat(_) => None,
                    vector_index_info::Config::HnswFlat(_) => Some(
                        self.core
                            .ctx()
                            .session_ctx()
                            .config()
                            .batch_hnsw_ef_search(),
                    ),
                };
                // Look up the `top_n` nearest index entries for the query vector (column 0 of
                // the input), returning every info column plus the distance.
                let info_column_desc = index.info_column_desc();
                let core = VectorIndexLookupJoin {
                    input: literal_vector_input,
                    top_n: self.core.top_n,
                    distance_type: self.core.distance_type,
                    index_name: index.index_table.name.clone(),
                    index_table_id: index.index_table.id,
                    info_output_indices: (0..info_column_desc.len()).collect(),
                    info_column_desc,
                    include_distance: true,
                    as_of: scan.as_of(),
                    vector_column_idx: 0,
                    hnsw_ef_search,
                    ctx: self.core.ctx(),
                };
                // Turn the lookup result into one row per hit, in three steps:
                // 1. unnest column 1 of the `BatchVectorSearch` output (the per-query result
                //    array) with a ProjectSet;
                // 2. read the resulting struct from column 1 of the ProjectSet output;
                // 3. project one column per struct field.
                let vector_search: BatchPlanRef = {
                    let vector_search: BatchPlanRef = BatchVectorSearch::with_core(core).into();
                    let unnested_array: BatchPlanRef = BatchProjectSet::new(generic::ProjectSet {
                        select_list: vec![ExprImpl::TableFunction(
                            TableFunction::new(
                                TableFunctionType::Unnest,
                                vec![ExprImpl::InputRef(
                                    InputRef::new(1, vector_search.schema()[1].data_type()).into(),
                                )],
                            )?
                            .into(),
                        )],
                        input: vector_search,
                    })
                    .into();
                    // The unnested element must be a struct; anything else is a planner bug.
                    let DataType::Struct(struct_type) = &unnested_array.schema()[1].data_type
                    else {
                        panic!("{:?}", unnested_array.schema()[1].data_type);
                    };
                    // `Field(struct, idx)` for each struct field, in order.
                    let unnest_struct = BatchProject::new(generic::Project::new(
                        struct_type
                            .types()
                            .enumerate()
                            .map(|(idx, data_type)| {
                                ExprImpl::FunctionCall(
                                    FunctionCall::new_unchecked(
                                        ExprType::Field,
                                        vec![
                                            ExprImpl::InputRef(
                                                InputRef::new(
                                                    1,
                                                    DataType::Struct(struct_type.clone()),
                                                )
                                                .into(),
                                            ),
                                            ExprImpl::Literal(
                                                Literal::new(
                                                    Some(ScalarImpl::Int32(idx as _)),
                                                    DataType::Int32,
                                                )
                                                .into(),
                                            ),
                                        ],
                                        data_type.clone(),
                                    )
                                    .into(),
                                )
                            })
                            .collect(),
                        unnested_array,
                    ));
                    unnest_struct.into()
                };
                // Info-column positions (in `vector_search` output) of the covered columns.
                let covered_output_col_idx = covered_table_cols_idx.iter().map(|table_col_idx| {
                    index.primary_to_included_info_column_mapping[table_col_idx]
                });
                return Ok(if non_covered_table_cols_idx.is_empty() {
                    // Every output column is covered by the index: project the covered columns
                    // followed by the distance column.
                    // NOTE(review): this branch uses `included_info_columns.len()` as the
                    // distance column index. The branch below uses the last column of
                    // `vector_search` instead. The two agree only if the index's info columns
                    // are exactly its included info columns; confirm.
                    BatchProject::new(generic::Project::with_out_col_idx(
                        vector_search,
                        covered_output_col_idx.chain([index.included_info_columns.len()]),
                    ))
                    .into()
                } else {
                    // Read the non-covered columns, then the primary key columns, from the
                    // primary table.
                    let mut primary_table_cols_idx = Vec::with_capacity(
                        non_covered_table_cols_idx.len() + scan.table().pk().len(),
                    );
                    primary_table_cols_idx.extend(
                        non_covered_table_cols_idx
                            .iter()
                            .cloned()
                            .chain(scan.table().pk().iter().map(|order| order.column_index)),
                    );
                    let table_scan = generic::TableScan::new(
                        primary_table_cols_idx,
                        scan.table().clone(),
                        vec![],
                        vec![],
                        self.core.input.ctx(),
                        Condition::true_cond(),
                        scan.as_of(),
                    );
                    let logical_scan = LogicalScan::from(table_scan);
                    let batch_scan = logical_scan.to_batch()?;
                    let vector_search_schema = vector_search.schema();
                    let vector_search_schema_len = vector_search_schema.len();
                    // Join condition, one null-safe equality per primary key column:
                    // - left side: that key column as found among the index info columns;
                    // - right side: the same key column read from the primary table. Its join
                    //   index is offset by the left side's width and by the non-covered columns
                    //   that precede the key in the scan.
                    let on_condition = Condition {
                        conjunctions: index
                            .primary_key_idx_in_info_columns
                            .iter()
                            .zip_eq_debug(0..scan.table().pk().len())
                            .map(|(pk_idx_in_info_columns, pk_idx)| {
                                let batch_scan_pk_idx = vector_search_schema.len()
                                    + non_covered_table_cols_idx.len()
                                    + pk_idx;
                                IndexSelectionRule::create_null_safe_equal_expr(
                                    *pk_idx_in_info_columns,
                                    vector_search_schema[*pk_idx_in_info_columns].data_type(),
                                    batch_scan_pk_idx,
                                    batch_scan.schema()[non_covered_table_cols_idx.len() + pk_idx]
                                        .data_type(),
                                )
                            })
                            .collect(),
                    };
                    let eq_predicate = EqJoinPredicate::create(
                        vector_search_schema.len(),
                        batch_scan.schema().len(),
                        on_condition.clone(),
                    );
                    // Output each forwarded column from the side that holds it:
                    // - covered columns come from the left side;
                    // - non-covered columns come from the right side, shifted by the left
                    //   side's width.
                    // The last column of the left side (the distance) follows them.
                    let join = generic::Join::new(
                        vector_search,
                        batch_scan,
                        on_condition,
                        JoinType::Inner,
                        primary_table_col_in_output
                            .iter()
                            .map(|(covered, idx)| {
                                if *covered {
                                    *idx
                                } else {
                                    *idx + vector_search_schema_len
                                }
                            })
                            // chain with distance column
                            .chain([vector_search_schema_len - 1])
                            .collect(),
                    );
                    let lookup_join = LogicalJoin::gen_batch_lookup_join(
                        &logical_scan,
                        eq_predicate,
                        join,
                        false,
                    )?
                    .ok_or_else(|| {
                        ErrorCode::InternalError(
                            "failed to convert to batch lookup join".to_owned(),
                        )
                    })?;
                    lookup_join.into()
                });
            }
        }
        // No usable vector index: compute the distance for every row and keep the top n.
        self.to_top_n().to_batch()
    }
}