risingwave_frontend/optimizer/plan_node/
logical_vector_search.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::array::VECTOR_DISTANCE_TYPE;
18use risingwave_common::bail;
19use risingwave_common::catalog::{Field, Schema};
20use risingwave_common::types::{DataType, ScalarImpl};
21use risingwave_common::util::column_index_mapping::ColIndexMapping;
22use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
23use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
24use risingwave_pb::catalog::vector_index_info;
25use risingwave_pb::common::PbDistanceType;
26use risingwave_pb::plan_common::JoinType;
27
28use crate::OptimizerContextRef;
29use crate::error::ErrorCode;
30use crate::expr::{
31    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32    TableFunction, TableFunctionType, collect_input_refs,
33};
34use crate::optimizer::plan_node::batch_vector_search::BatchVectorSearchCore;
35use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
36use crate::optimizer::plan_node::generic::{
37    GenericPlanNode, GenericPlanRef, TopNLimit, ensure_sorted_required_cols,
38};
39use crate::optimizer::plan_node::utils::{Distill, childless_record};
40use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
41use crate::optimizer::property::{FunctionalDependencySet, Order};
42use crate::optimizer::rule::IndexSelectionRule;
43use crate::utils::{ColIndexMappingRewriteExt, Condition};
44
/// Core (plan-kind independent) description of a vector search.
///
/// Semantically it keeps the `top_n` rows of `input` with the smallest distance between
/// `left` and `right` under `distance_type` (see `LogicalVectorSearch::to_top_n`). The output
/// is the input columns listed in `output_indices`, followed by one distance column.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VectorSearchCore {
    /// Maximum number of rows produced.
    top_n: u64,
    /// Distance metric. It selects the distance function in `to_top_n` and must equal the
    /// index's distance type for an index to be used in `to_batch`.
    distance_type: PbDistanceType,
    /// First operand of the distance function.
    left: ExprImpl,
    /// Second operand of the distance function. For index-based execution exactly one of
    /// `left`/`right` must be constant (literals and functions only).
    right: ExprImpl,
    /// Indices of the input columns included in the output, in output order.
    /// The distance column comes right after them, at index `output_indices.len()`.
    output_indices: Vec<usize>,
    /// The child plan searched over.
    input: PlanRef,
}
56
57impl VectorSearchCore {
58    pub(crate) fn clone_with_input(&self, input: PlanRef) -> Self {
59        Self {
60            top_n: self.top_n,
61            distance_type: self.distance_type,
62            left: self.left.clone(),
63            right: self.right.clone(),
64            output_indices: self.output_indices.clone(),
65            input,
66        }
67    }
68
69    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
70        self.left = r.rewrite_expr(self.left.clone());
71        self.right = r.rewrite_expr(self.right.clone());
72    }
73
74    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
75        v.visit_expr(&self.left);
76        v.visit_expr(&self.right);
77    }
78
79    pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
80        let mut mapping = vec![None; self.input.schema().len()];
81        for (output_idx, input_idx) in self.output_indices.iter().enumerate() {
82            mapping[*input_idx] = Some(output_idx);
83        }
84        ColIndexMapping::new(mapping, self.output_indices.len() + 1)
85    }
86}
87
88impl GenericPlanNode for VectorSearchCore {
89    fn functional_dependency(&self) -> FunctionalDependencySet {
90        self.i2o_mapping()
91            .rewrite_functional_dependency_set(self.input.functional_dependency().clone())
92    }
93
94    fn schema(&self) -> Schema {
95        let fields = self
96            .output_indices
97            .iter()
98            .map(|idx| self.input.schema()[*idx].clone())
99            .chain([Field::new("vector_distance", DataType::Float64)])
100            .collect();
101        Schema { fields }
102    }
103
104    fn stream_key(&self) -> Option<Vec<usize>> {
105        self.input.stream_key().and_then(|v| {
106            let i2o_mapping = self.i2o_mapping();
107            v.iter().map(|idx| i2o_mapping.try_map(*idx)).collect()
108        })
109    }
110
111    fn ctx(&self) -> OptimizerContextRef {
112        self.input.ctx()
113    }
114}
115
/// Logical plan node for a top-n vector similarity search over its input.
///
/// `to_batch` lowers it either to an index-based plan (when a matching vector index exists
/// on a plain table scan) or to a `TopN` over a distance projection. Streaming is not
/// supported.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalVectorSearch {
    /// Plan base built from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// The search parameters and child plan.
    core: VectorSearchCore,
}
121
122impl LogicalVectorSearch {
123    pub(crate) fn new(
124        top_n: u64,
125        distance_type: PbDistanceType,
126        left: ExprImpl,
127        right: ExprImpl,
128        output_indices: Vec<usize>,
129        input: PlanRef,
130    ) -> Self {
131        let core = VectorSearchCore {
132            top_n,
133            distance_type,
134            left,
135            right,
136            output_indices,
137            input,
138        };
139        Self::with_core(core)
140    }
141
142    fn with_core(core: VectorSearchCore) -> Self {
143        let base = PlanBase::new_logical_with_core(&core);
144        Self { base, core }
145    }
146
147    pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
148        self.core.i2o_mapping()
149    }
150}
151
152impl_plan_tree_node_for_unary! { Logical, LogicalVectorSearch }
153
154impl PlanTreeNodeUnary<Logical> for LogicalVectorSearch {
155    fn input(&self) -> PlanRef {
156        self.core.input.clone()
157    }
158
159    fn clone_with_input(&self, input: PlanRef) -> Self {
160        let core = self.core.clone_with_input(input);
161        Self::with_core(core)
162    }
163}
164
165impl Distill for LogicalVectorSearch {
166    fn distill<'a>(&self) -> XmlNode<'a> {
167        let verbose = self.base.ctx().is_explain_verbose();
168        let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
169        vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
170        vec.push(("top_n", Pretty::debug(&self.core.top_n)));
171        vec.push(("left", Pretty::debug(&self.core.left)));
172        vec.push(("right", Pretty::debug(&self.core.right)));
173
174        if verbose {
175            vec.push((
176                "output_columns",
177                Pretty::Array(
178                    self.core
179                        .output_indices
180                        .iter()
181                        .map(|input_idx| {
182                            Pretty::debug(&self.core.input.schema().fields()[*input_idx])
183                        })
184                        .collect(),
185                ),
186            ));
187        }
188
189        childless_record("LogicalVectorSearch", vec)
190    }
191}
192
193impl ColPrunable for LogicalVectorSearch {
194    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
195        let (project_exprs, required_cols) =
196            ensure_sorted_required_cols(required_cols, self.base.schema());
197        assert!(required_cols.is_sorted());
198        let input_schema = self.core.input.schema();
199        let mut required_input_idx_bitset =
200            collect_input_refs(input_schema.len(), [&self.core.left, &self.core.right]);
201        let mut non_distance_required_input_idx = Vec::new();
202        let require_distance_col = required_cols
203            .last()
204            .map(|last_col_idx| *last_col_idx == self.core.output_indices.len())
205            .unwrap_or(false);
206        let non_distance_iter_end_idx = if require_distance_col {
207            required_cols.len() - 1
208        } else {
209            required_cols.len()
210        };
211        for &required_col_idx in &required_cols[0..non_distance_iter_end_idx] {
212            let required_input_idx = self.core.output_indices[required_col_idx];
213            non_distance_required_input_idx.push(required_input_idx);
214            required_input_idx_bitset.set(required_col_idx, true);
215        }
216        let input_required_idx = required_input_idx_bitset.ones().collect_vec();
217
218        let new_input = self.input().prune_col(&input_required_idx, ctx);
219        // mapping from idx of original input to new input
220        let mut mapping = ColIndexMapping::with_remaining_columns(
221            &input_required_idx,
222            self.input().schema().len(),
223        );
224
225        let vector_search = {
226            let mut new_core = self.core.clone_with_input(new_input);
227            new_core.left = mapping.rewrite_expr(new_core.left);
228            new_core.right = mapping.rewrite_expr(new_core.right);
229            new_core.output_indices = non_distance_required_input_idx
230                .iter()
231                .map(|input_idx| mapping.map(*input_idx))
232                .collect();
233            Self::with_core(new_core)
234        };
235        LogicalProject::create(vector_search.into(), project_exprs)
236    }
237}
238
239impl ExprRewritable<Logical> for LogicalVectorSearch {
240    fn has_rewritable_expr(&self) -> bool {
241        true
242    }
243
244    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
245        let mut core = self.core.clone();
246        core.rewrite_exprs(r);
247        Self::with_core(core).into()
248    }
249}
250
251impl ExprVisitable for LogicalVectorSearch {
252    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
253        self.core.visit_exprs(v);
254    }
255}
256
257impl PredicatePushdown for LogicalVectorSearch {
258    fn predicate_pushdown(
259        &self,
260        predicate: Condition,
261        ctx: &mut PredicatePushdownContext,
262    ) -> PlanRef {
263        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
264    }
265}
266
267impl ToStream for LogicalVectorSearch {
268    fn logical_rewrite_for_stream(
269        &self,
270        _ctx: &mut RewriteStreamContext,
271    ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
272        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
273    }
274
275    fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
276        bail!("LogicalVectorSearch can only for batch plan, not stream plan");
277    }
278}
279
280impl LogicalVectorSearch {
281    fn to_top_n(&self) -> LogicalTopN {
282        let (neg, expr_type) = match self.core.distance_type {
283            PbDistanceType::Unspecified => {
284                unreachable!()
285            }
286            PbDistanceType::L1 => (false, ExprType::L1Distance),
287            PbDistanceType::L2Sqr => (false, ExprType::L2Distance),
288            PbDistanceType::Cosine => (false, ExprType::CosineDistance),
289            PbDistanceType::InnerProduct => (true, ExprType::InnerProduct),
290        };
291        let mut expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
292            expr_type,
293            vec![self.core.left.clone(), self.core.right.clone()],
294            VECTOR_DISTANCE_TYPE,
295        )));
296        if neg {
297            expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
298                ExprType::Neg,
299                vec![expr],
300                VECTOR_DISTANCE_TYPE,
301            )));
302        }
303        let exprs = generic::Project::out_col_idx_exprs(
304            &self.core.input,
305            self.core.output_indices.iter().copied(),
306        )
307        .chain([expr])
308        .collect();
309
310        let input = LogicalProject::new(self.input(), exprs).into();
311        let top_n = generic::TopN::without_group(
312            input,
313            TopNLimit::Simple(self.core.top_n),
314            0,
315            Order::new(vec![ColumnOrder::new(
316                self.core.output_indices.len(),
317                OrderType::ascending(),
318            )]),
319        );
320        top_n.into()
321    }
322
323    fn as_vector_table_scan(&self) -> Option<(&LogicalScan, ExprImpl, &ExprImpl)> {
324        let scan = self.core.input.as_logical_scan()?;
325        if !scan.predicate().always_true() {
326            return None;
327        }
328        let left_const = (self.core.left.only_literal_and_func(), &self.core.left);
329        let right_const = (self.core.right.only_literal_and_func(), &self.core.right);
330        let (vector_column_expr, vector_expr) = match (left_const, right_const) {
331            ((true, _), (true, _)) => {
332                return None;
333            }
334            ((_, vector_column_expr), (true, vector_expr))
335            | ((true, vector_expr), (_, vector_column_expr)) => (vector_column_expr, vector_expr),
336            _ => return None,
337        };
338        Some((scan, vector_expr.clone(), vector_column_expr))
339    }
340
341    fn is_matched_vector_column_expr(
342        index_expr: &ExprImpl,
343        column_expr: &ExprImpl,
344        scan_output_col_idx: &[usize],
345    ) -> bool {
346        match (index_expr, column_expr) {
347            (ExprImpl::Literal(l1), ExprImpl::Literal(l2)) => l1 == l2,
348            (ExprImpl::InputRef(i1), ExprImpl::InputRef(i2)) => {
349                i1.index == scan_output_col_idx[i2.index]
350            }
351            (ExprImpl::FunctionCall(f1), ExprImpl::FunctionCall(f2)) => {
352                f1.func_type() == f2.func_type()
353                    && f1.return_type() == f2.return_type()
354                    && f1.inputs().len() == f2.inputs().len()
355                    && f1.inputs().iter().zip_eq_fast(f2.inputs()).all(|(e1, e2)| {
356                        Self::is_matched_vector_column_expr(e1, e2, scan_output_col_idx)
357                    })
358            }
359            _ => false,
360        }
361    }
362}
363
impl ToBatch for LogicalVectorSearch {
    /// Converts this node into a batch plan.
    ///
    /// The index-based plan is used when all of the following hold:
    /// - the input is an unfiltered table scan that has vector indexes;
    /// - index selection is enabled;
    /// - one operand is constant;
    /// - the other operand matches an index's vector expression, and that index has the same
    ///   distance type.
    ///
    /// In that case the search is answered from the index:
    /// `BatchValues` (the constant query vector) -> `BatchVectorSearch` -> unnest of its
    /// result column -> projection of each struct field into its own column. The resulting
    /// columns are the index's info columns, with the distance at index
    /// `included_info_columns.len()`. Requested columns that are not info columns of the index
    /// are fetched from the primary table through a batch lookup join on the primary key.
    ///
    /// Otherwise it falls back to `to_top_n`: a `TopN` over a projection that computes the
    /// distance for every input row.
    ///
    /// # Errors
    /// Propagates errors from building the unnest table function, converting the primary
    /// table scan, or generating the lookup join. It also errors if no lookup join could be
    /// generated.
    fn to_batch(&self) -> crate::error::Result<BatchPlanRef> {
        if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan()
            && !scan.vector_indexes().is_empty()
            && self
                .core
                .ctx()
                .session_ctx()
                .config()
                .enable_index_selection()
        {
            // The first index whose vector expression and distance type both match is used.
            for index in scan.vector_indexes() {
                if !Self::is_matched_vector_column_expr(
                    &index.vector_expr,
                    vector_column_expr,
                    scan.output_col_idx(),
                ) {
                    continue;
                }
                if index.vector_index_info.distance_type() != self.core.distance_type {
                    continue;
                }

                // Primary-table column indices of the requested output columns.
                let primary_table_cols_idx = self
                    .core
                    .output_indices
                    .iter()
                    .map(|input_idx| scan.output_col_idx()[*input_idx])
                    .collect_vec();
                let mut covered_table_cols_idx = Vec::new();
                let mut non_covered_table_cols_idx = Vec::new();
                // One entry per output column, in output order:
                // - `(true, i)`: the column is info column `i` of the index;
                // - `(false, j)`: the column is position `j` in `non_covered_table_cols_idx`
                //   and must be read from the primary table.
                let mut primary_table_col_in_output =
                    Vec::with_capacity(primary_table_cols_idx.len());
                for table_col_idx in &primary_table_cols_idx {
                    if let Some(covered_info_column_idx) = index
                        .primary_to_included_info_column_mapping
                        .get(table_col_idx)
                    {
                        covered_table_cols_idx.push(*table_col_idx);
                        primary_table_col_in_output.push((true, *covered_info_column_idx));
                    } else {
                        primary_table_col_in_output.push((false, non_covered_table_cols_idx.len()));
                        non_covered_table_cols_idx.push(*table_col_idx);
                    }
                }
                // Single-row, single-column input carrying the constant query vector.
                let vector_data_type = vector_expr.return_type();
                let literal_vector_input = BatchValues::new(LogicalValues::new(
                    vec![vec![vector_expr]],
                    Schema::from_iter([Field::unnamed(vector_data_type)]),
                    self.core.ctx(),
                ))
                .into();
                // Only HNSW indexes take the session's `batch_hnsw_ef_search` setting.
                let hnsw_ef_search = match index.vector_index_info.config.as_ref().unwrap() {
                    vector_index_info::Config::Flat(_) => None,
                    vector_index_info::Config::HnswFlat(_) => Some(
                        self.core
                            .ctx()
                            .session_ctx()
                            .config()
                            .batch_hnsw_ef_search(),
                    ),
                };
                let core = BatchVectorSearchCore {
                    input: literal_vector_input,
                    top_n: self.core.top_n,
                    distance_type: self.core.distance_type,
                    index_name: index.index_table.name.clone(),
                    index_table_id: index.index_table.id,
                    // NOTE(review): assumes the info columns are columns 1..=N of the index
                    // table (presumably column 0 is the vector column) — confirm.
                    info_column_desc: index.index_table.columns
                        [1..=index.included_info_columns.len()]
                        .iter()
                        .map(|col| col.column_desc.clone())
                        .collect(),
                    vector_column_idx: 0,
                    hnsw_ef_search,
                    ctx: self.core.ctx(),
                };
                // Turn the search result into one row per hit, with one column per struct
                // field.
                let vector_search: BatchPlanRef = {
                    let vector_search: BatchPlanRef = BatchVectorSearch::with_core(core).into();
                    // Unnest column 1 of the vector search output.
                    let unnested_array: BatchPlanRef = BatchProjectSet::new(generic::ProjectSet {
                        select_list: vec![ExprImpl::TableFunction(
                            TableFunction::new(
                                TableFunctionType::Unnest,
                                vec![ExprImpl::InputRef(
                                    InputRef::new(1, vector_search.schema()[1].data_type()).into(),
                                )],
                            )?
                            .into(),
                        )],
                        input: vector_search,
                    })
                    .into();
                    // Column 1 of the ProjectSet output must be the unnested struct.
                    // Presumably column 0 is the projected row id — confirm.
                    let DataType::Struct(struct_type) = &unnested_array.schema()[1].data_type
                    else {
                        panic!("{:?}", unnested_array.schema()[1].data_type);
                    };
                    // Project every struct field (by position) into its own column.
                    let unnest_struct = BatchProject::new(generic::Project::new(
                        struct_type
                            .types()
                            .enumerate()
                            .map(|(idx, data_type)| {
                                ExprImpl::FunctionCall(
                                    FunctionCall::new_unchecked(
                                        ExprType::Field,
                                        vec![
                                            ExprImpl::InputRef(
                                                InputRef::new(
                                                    1,
                                                    DataType::Struct(struct_type.clone()),
                                                )
                                                .into(),
                                            ),
                                            ExprImpl::Literal(
                                                Literal::new(
                                                    Some(ScalarImpl::Int32(idx as _)),
                                                    DataType::Int32,
                                                )
                                                .into(),
                                            ),
                                        ],
                                        data_type.clone(),
                                    )
                                    .into(),
                                )
                            })
                            .collect(),
                        unnested_array,
                    ));
                    unnest_struct.into()
                };
                let covered_output_col_idx = covered_table_cols_idx.iter().map(|table_col_idx| {
                    index.primary_to_included_info_column_mapping[table_col_idx]
                });
                return Ok(if non_covered_table_cols_idx.is_empty() {
                    // Every requested column is an info column: select those columns plus the
                    // distance (at `included_info_columns.len()`).
                    BatchProject::new(generic::Project::with_out_col_idx(
                        vector_search,
                        covered_output_col_idx.chain([index.included_info_columns.len()]),
                    ))
                    .into()
                } else {
                    // Shadows the outer variable. This is the column list of the primary-table
                    // scan: non-covered columns first, then the primary key columns.
                    let mut primary_table_cols_idx = Vec::with_capacity(
                        non_covered_table_cols_idx.len() + scan.table().pk().len(),
                    );
                    primary_table_cols_idx.extend(
                        non_covered_table_cols_idx
                            .iter()
                            .cloned()
                            .chain(scan.table().pk().iter().map(|order| order.column_index)),
                    );
                    let table_scan = generic::TableScan::new(
                        primary_table_cols_idx,
                        scan.table().clone(),
                        vec![],
                        vec![],
                        self.core.input.ctx(),
                        Condition::true_cond(),
                        None,
                    );
                    let logical_scan = LogicalScan::from(table_scan);
                    let batch_scan = logical_scan.to_batch()?;
                    let vector_search_schema = vector_search.schema();
                    // Captured as a plain `usize`: `vector_search` is moved into the join
                    // below, but its length is still needed to build the output indices.
                    let vector_search_schema_len = vector_search_schema.len();
                    // Null-safe equality on each primary key column:
                    // - left side: the pk column's position among the index's info columns;
                    // - right side: the pk column in the scan output, offset by the left
                    //   schema length.
                    let on_condition = Condition {
                        conjunctions: index
                            .primary_key_idx_in_info_columns
                            .iter()
                            .zip_eq_debug(0..scan.table().pk().len())
                            .map(|(pk_idx_in_info_columns, pk_idx)| {
                                let batch_scan_pk_idx = vector_search_schema.len()
                                    + non_covered_table_cols_idx.len()
                                    + pk_idx;
                                IndexSelectionRule::create_null_safe_equal_expr(
                                    *pk_idx_in_info_columns,
                                    vector_search_schema[*pk_idx_in_info_columns].data_type(),
                                    batch_scan_pk_idx,
                                    batch_scan.schema()[non_covered_table_cols_idx.len() + pk_idx]
                                        .data_type(),
                                )
                            })
                            .collect(),
                    };
                    let eq_predicate = EqJoinPredicate::create(
                        vector_search_schema.len(),
                        batch_scan.schema().len(),
                        on_condition.clone(),
                    );
                    // Output columns, in the original output order:
                    // - covered columns come from the left (index) side;
                    // - non-covered columns come from the right (scan) side, shifted past the
                    //   left columns.
                    let join = generic::Join::new(
                        vector_search,
                        batch_scan,
                        on_condition,
                        JoinType::Inner,
                        primary_table_col_in_output
                            .iter()
                            .map(|(covered, idx)| {
                                if *covered {
                                    *idx
                                } else {
                                    *idx + vector_search_schema_len
                                }
                            })
                            // chain with distance column
                            .chain([vector_search_schema_len - 1])
                            .collect(),
                    );
                    let lookup_join = LogicalJoin::gen_batch_lookup_join(
                        &logical_scan,
                        eq_predicate,
                        join,
                        false,
                    )?
                    .ok_or_else(|| {
                        ErrorCode::InternalError(
                            "failed to convert to batch lookup join".to_owned(),
                        )
                    })?;
                    lookup_join.into()
                });
            }
        }
        // No usable index: compute the distance for every row and take the top n.
        self.to_top_n().to_batch()
    }
}