1use std::sync::Arc;
16
17use itertools::Itertools;
18use pretty_xmlish::{Pretty, XmlNode};
19use risingwave_common::array::VECTOR_DISTANCE_TYPE;
20use risingwave_common::bail;
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::types::{DataType, ScalarImpl};
23use risingwave_common::util::column_index_mapping::ColIndexMapping;
24use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
25use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
26use risingwave_pb::common::PbDistanceType;
27use risingwave_pb::plan_common::JoinType;
28
29use crate::OptimizerContextRef;
30use crate::catalog::index_catalog::VectorIndex;
31use crate::error::ErrorCode;
32use crate::expr::{
33 Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
34 TableFunction, TableFunctionType, collect_input_refs,
35};
36use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
37use crate::optimizer::plan_node::generic::{
38 GenericPlanNode, GenericPlanRef, TopNLimit, VectorIndexLookupJoin, ensure_sorted_required_cols,
39};
40use crate::optimizer::plan_node::utils::{Distill, childless_record};
41use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
42use crate::optimizer::property::{FunctionalDependencySet, Order};
43use crate::optimizer::rule::IndexSelectionRule;
44use crate::utils::{ColIndexMappingRewriteExt, Condition};
45
46#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47struct VectorSearchCore {
48 top_n: u64,
49 distance_type: PbDistanceType,
50 left: ExprImpl,
51 right: ExprImpl,
52 output_indices: Vec<usize>,
55 input: PlanRef,
56}
57
58impl VectorSearchCore {
59 pub(crate) fn clone_with_input(&self, input: PlanRef) -> Self {
60 Self {
61 top_n: self.top_n,
62 distance_type: self.distance_type,
63 left: self.left.clone(),
64 right: self.right.clone(),
65 output_indices: self.output_indices.clone(),
66 input,
67 }
68 }
69
70 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
71 self.left = r.rewrite_expr(self.left.clone());
72 self.right = r.rewrite_expr(self.right.clone());
73 }
74
75 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
76 v.visit_expr(&self.left);
77 v.visit_expr(&self.right);
78 }
79
80 pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
81 let mut mapping = vec![None; self.input.schema().len()];
82 for (output_idx, input_idx) in self.output_indices.iter().enumerate() {
83 mapping[*input_idx] = Some(output_idx);
84 }
85 ColIndexMapping::new(mapping, self.output_indices.len() + 1)
86 }
87}
88
89impl GenericPlanNode for VectorSearchCore {
90 fn functional_dependency(&self) -> FunctionalDependencySet {
91 self.i2o_mapping()
92 .rewrite_functional_dependency_set(self.input.functional_dependency().clone())
93 }
94
95 fn schema(&self) -> Schema {
96 let fields = self
97 .output_indices
98 .iter()
99 .map(|idx| self.input.schema()[*idx].clone())
100 .chain([Field::new("vector_distance", DataType::Float64)])
101 .collect();
102 Schema { fields }
103 }
104
105 fn stream_key(&self) -> Option<Vec<usize>> {
106 self.input.stream_key().and_then(|v| {
107 let i2o_mapping = self.i2o_mapping();
108 v.iter().map(|idx| i2o_mapping.try_map(*idx)).collect()
109 })
110 }
111
112 fn ctx(&self) -> OptimizerContextRef {
113 self.input.ctx()
114 }
115}
116
117#[derive(Debug, Clone, PartialEq, Eq, Hash)]
118pub struct LogicalVectorSearch {
119 pub base: PlanBase<Logical>,
120 core: VectorSearchCore,
121}
122
123impl LogicalVectorSearch {
124 pub(crate) fn new(
125 top_n: u64,
126 distance_type: PbDistanceType,
127 left: ExprImpl,
128 right: ExprImpl,
129 input: PlanRef,
130 ) -> Self {
131 let output_indices = (0..input.schema().len()).collect();
132 let core = VectorSearchCore {
133 top_n,
134 distance_type,
135 left,
136 right,
137 output_indices,
138 input,
139 };
140 Self::with_core(core)
141 }
142
143 fn with_core(core: VectorSearchCore) -> Self {
144 let base = PlanBase::new_logical_with_core(&core);
145 Self { base, core }
146 }
147}
148
149impl_plan_tree_node_for_unary! { Logical, LogicalVectorSearch }
150
151impl PlanTreeNodeUnary<Logical> for LogicalVectorSearch {
152 fn input(&self) -> PlanRef {
153 self.core.input.clone()
154 }
155
156 fn clone_with_input(&self, input: PlanRef) -> Self {
157 let core = self.core.clone_with_input(input);
158 Self::with_core(core)
159 }
160}
161
162impl Distill for LogicalVectorSearch {
163 fn distill<'a>(&self) -> XmlNode<'a> {
164 let verbose = self.base.ctx().is_explain_verbose();
165 let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
166 vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
167 vec.push(("top_n", Pretty::debug(&self.core.top_n)));
168 vec.push(("left", Pretty::debug(&self.core.left)));
169 vec.push(("right", Pretty::debug(&self.core.right)));
170
171 if verbose {
172 vec.push((
173 "output_columns",
174 Pretty::Array(
175 self.core
176 .output_indices
177 .iter()
178 .map(|input_idx| {
179 Pretty::debug(&self.core.input.schema().fields()[*input_idx])
180 })
181 .collect(),
182 ),
183 ));
184 }
185
186 childless_record("LogicalVectorSearch", vec)
187 }
188}
189
190impl ColPrunable for LogicalVectorSearch {
191 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
192 let (project_exprs, required_cols) =
193 ensure_sorted_required_cols(required_cols, self.base.schema());
194 assert!(required_cols.is_sorted());
195 let input_schema = self.core.input.schema();
196 let mut required_input_idx_bitset =
197 collect_input_refs(input_schema.len(), [&self.core.left, &self.core.right]);
198 let mut non_distance_required_input_idx = Vec::new();
199 let require_distance_col = required_cols
200 .last()
201 .map(|last_col_idx| *last_col_idx == self.core.output_indices.len())
202 .unwrap_or(false);
203 let non_distance_iter_end_idx = if require_distance_col {
204 required_cols.len() - 1
205 } else {
206 required_cols.len()
207 };
208 for &required_col_idx in &required_cols[0..non_distance_iter_end_idx] {
209 let required_input_idx = self.core.output_indices[required_col_idx];
210 non_distance_required_input_idx.push(required_input_idx);
211 required_input_idx_bitset.set(required_col_idx, true);
212 }
213 let input_required_idx = required_input_idx_bitset.ones().collect_vec();
214
215 let new_input = self.input().prune_col(&input_required_idx, ctx);
216 let mut mapping = ColIndexMapping::with_remaining_columns(
218 &input_required_idx,
219 self.input().schema().len(),
220 );
221
222 let vector_search = {
223 let mut new_core = self.core.clone_with_input(new_input);
224 new_core.left = mapping.rewrite_expr(new_core.left);
225 new_core.right = mapping.rewrite_expr(new_core.right);
226 new_core.output_indices = non_distance_required_input_idx
227 .iter()
228 .map(|input_idx| mapping.map(*input_idx))
229 .collect();
230 Self::with_core(new_core)
231 };
232 LogicalProject::create(vector_search.into(), project_exprs)
233 }
234}
235
236impl ExprRewritable<Logical> for LogicalVectorSearch {
237 fn has_rewritable_expr(&self) -> bool {
238 true
239 }
240
241 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
242 let mut core = self.core.clone();
243 core.rewrite_exprs(r);
244 Self::with_core(core).into()
245 }
246}
247
248impl ExprVisitable for LogicalVectorSearch {
249 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
250 self.core.visit_exprs(v);
251 }
252}
253
254impl PredicatePushdown for LogicalVectorSearch {
255 fn predicate_pushdown(
256 &self,
257 predicate: Condition,
258 ctx: &mut PredicatePushdownContext,
259 ) -> PlanRef {
260 gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
261 }
262}
263
264impl ToStream for LogicalVectorSearch {
265 fn logical_rewrite_for_stream(
266 &self,
267 _ctx: &mut RewriteStreamContext,
268 ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
269 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
270 }
271
272 fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
273 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
274 }
275}
276
277impl LogicalVectorSearch {
278 pub fn to_top_n(
279 input: PlanRef,
280 left: ExprImpl,
281 right: ExprImpl,
282 distance_type: PbDistanceType,
283 top_n: u64,
284 output_indices: &[usize],
285 ) -> LogicalTopN {
286 let (neg, expr_type) = match distance_type {
287 PbDistanceType::Unspecified => {
288 unreachable!()
289 }
290 PbDistanceType::L1 => (false, ExprType::L1Distance),
291 PbDistanceType::L2Sqr => (false, ExprType::L2Distance),
292 PbDistanceType::Cosine => (false, ExprType::CosineDistance),
293 PbDistanceType::InnerProduct => (true, ExprType::InnerProduct),
294 };
295 let mut expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
296 expr_type,
297 vec![left, right],
298 VECTOR_DISTANCE_TYPE,
299 )));
300 if neg {
301 expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
302 ExprType::Neg,
303 vec![expr],
304 VECTOR_DISTANCE_TYPE,
305 )));
306 }
307 let exprs = generic::Project::out_col_idx_exprs(&input, output_indices.iter().copied())
308 .chain([expr])
309 .collect();
310
311 let input = LogicalProject::new(input, exprs).into();
312 let top_n = generic::TopN::without_group(
313 input,
314 TopNLimit::Simple(top_n),
315 0,
316 Order::new(vec![ColumnOrder::new(
317 output_indices.len(),
318 OrderType::ascending(),
319 )]),
320 );
321 top_n.into()
322 }
323
324 fn as_vector_table_scan(&self) -> Option<(&LogicalScan, ExprImpl, &ExprImpl)> {
325 let scan = self.core.input.as_logical_scan()?;
326 let left_const = (self.core.left.only_literal_and_func(), &self.core.left);
327 let right_const = (self.core.right.only_literal_and_func(), &self.core.right);
328 let (vector_column_expr, vector_expr) = match (left_const, right_const) {
329 ((true, _), (true, _)) => {
330 return None;
331 }
332 ((_, vector_column_expr), (true, vector_expr))
333 | ((true, vector_expr), (_, vector_column_expr)) => (vector_column_expr, vector_expr),
334 _ => return None,
335 };
336 Some((scan, vector_expr.clone(), vector_column_expr))
337 }
338
339 fn is_matched_vector_column_expr(
340 index_expr: &ExprImpl,
341 column_expr: &ExprImpl,
342 scan_output_col_idx: &[usize],
343 ) -> bool {
344 match (index_expr, column_expr) {
345 (ExprImpl::Literal(l1), ExprImpl::Literal(l2)) => l1 == l2,
346 (ExprImpl::InputRef(i1), ExprImpl::InputRef(i2)) => {
347 i1.index == scan_output_col_idx[i2.index]
348 }
349 (ExprImpl::FunctionCall(f1), ExprImpl::FunctionCall(f2)) => {
350 f1.func_type() == f2.func_type()
351 && f1.return_type() == f2.return_type()
352 && f1.inputs().len() == f2.inputs().len()
353 && f1.inputs().iter().zip_eq_fast(f2.inputs()).all(|(e1, e2)| {
354 Self::is_matched_vector_column_expr(e1, e2, scan_output_col_idx)
355 })
356 }
357 _ => false,
358 }
359 }
360}
361
362impl LogicalVectorSearch {
363 #[expect(clippy::type_complexity)]
364 pub fn resolve_vector_index_lookup<'a>(
365 scan: &'a LogicalScan,
366 vector_column_expr: &ExprImpl,
367 distance_type: PbDistanceType,
368 output_indices: &[usize],
369 ) -> Option<(
370 &'a Arc<VectorIndex>,
371 Vec<usize>,
372 Vec<usize>,
373 Vec<(bool, usize)>,
374 )> {
375 if !scan.predicate().always_true() {
376 return None;
377 }
378 if !scan.vector_indexes().is_empty() {
379 let primary_table_cols_idx = output_indices
380 .iter()
381 .map(|input_idx| scan.output_col_idx()[*input_idx])
382 .collect_vec();
383 for index in scan.vector_indexes() {
384 if !Self::is_matched_vector_column_expr(
385 &index.vector_expr,
386 vector_column_expr,
387 scan.output_col_idx(),
388 ) {
389 continue;
390 }
391 if index.vector_index_info.distance_type() != distance_type {
392 continue;
393 }
394
395 let mut covered_table_cols_idx = Vec::new();
396 let mut non_covered_table_cols_idx = Vec::new();
397 let mut primary_table_col_in_output =
398 Vec::with_capacity(primary_table_cols_idx.len());
399 for table_col_idx in &primary_table_cols_idx {
400 if let Some(covered_info_column_idx) = index
401 .primary_to_included_info_column_mapping
402 .get(table_col_idx)
403 {
404 covered_table_cols_idx.push(*table_col_idx);
405 primary_table_col_in_output.push((true, *covered_info_column_idx));
406 } else {
407 primary_table_col_in_output.push((false, non_covered_table_cols_idx.len()));
408 non_covered_table_cols_idx.push(*table_col_idx);
409 }
410 }
411 return Some((
412 index,
413 covered_table_cols_idx,
414 non_covered_table_cols_idx,
415 primary_table_col_in_output,
416 ));
417 }
418 }
419 None
420 }
421}
422
impl ToBatch for LogicalVectorSearch {
    /// Lowers this node to a batch plan.
    ///
    /// Index path: taken when all three hold:
    /// - the input is a `LogicalScan` with exactly one constant operand
    ///   (`as_vector_table_scan`);
    /// - `enable_index_selection` is set in the session config;
    /// - `resolve_vector_index_lookup` finds a matching vector index.
    ///
    /// The query vector is then fed to `BatchVectorSearch` over that index, and the result
    /// is flattened into columns. If some requested columns are not stored in the index,
    /// they are fetched from the primary table with a batch lookup join on the primary key.
    ///
    /// Fallback: a `Project` + `TopN` plan built by `to_top_n`.
    fn to_batch(&self) -> Result<BatchPlanRef> {
        if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan()
            && self
                .core
                .ctx()
                .session_ctx()
                .config()
                .enable_index_selection()
            && let Some((
                index,
                covered_table_cols_idx,
                non_covered_table_cols_idx,
                primary_table_col_in_output,
            )) = Self::resolve_vector_index_lookup(
                scan,
                vector_column_expr,
                self.core.distance_type,
                &self.core.output_indices,
            )
        {
            {
                // One-row `Values` input holding the constant query vector as its only
                // column, which is why `vector_column_idx` is 0 below.
                let vector_data_type = vector_expr.return_type();
                let literal_vector_input = BatchValues::new(LogicalValues::new(
                    vec![vec![vector_expr]],
                    Schema::from_iter([Field::new("query_vector", vector_data_type)]),
                    self.core.ctx(),
                ))
                .into();
                let hnsw_ef_search =
                    index.resolve_hnsw_ef_search(&self.core.ctx().session_ctx().config());
                let info_column_desc = index.info_column_desc();
                // Index lookup configured with this node's `top_n` and `distance_type`.
                // It requests every info column plus the distance (`include_distance`).
                let core = VectorIndexLookupJoin {
                    input: literal_vector_input,
                    top_n: self.core.top_n,
                    distance_type: self.core.distance_type,
                    index_name: index.index_table.name.clone(),
                    index_table_id: index.index_table.id,
                    info_output_indices: (0..info_column_desc.len()).collect(),
                    info_column_desc,
                    include_distance: true,
                    as_of: scan.as_of(),
                    vector_column_idx: 0,
                    hnsw_ef_search,
                    ctx: self.core.ctx(),
                };
                // Flatten the lookup result into plain columns:
                // 1. `BatchProjectSet` unnests column 1 of the vector search output; the
                //    unnested value is read back from column 1 of the project set.
                // 2. That column must be a struct (otherwise this panics). `BatchProject`
                //    splits it into one column per struct field via `ExprType::Field`.
                let vector_search: BatchPlanRef = {
                    let vector_search: BatchPlanRef = BatchVectorSearch::with_core(core).into();
                    let unnested_array: BatchPlanRef = BatchProjectSet::new(generic::ProjectSet {
                        select_list: vec![ExprImpl::TableFunction(
                            TableFunction::new(
                                TableFunctionType::Unnest,
                                vec![ExprImpl::InputRef(
                                    InputRef::new(1, vector_search.schema()[1].data_type()).into(),
                                )],
                            )?
                            .into(),
                        )],
                        input: vector_search,
                    })
                    .into();
                    let DataType::Struct(struct_type) = &unnested_array.schema()[1].data_type
                    else {
                        panic!("{:?}", unnested_array.schema()[1].data_type);
                    };
                    let unnest_struct = BatchProject::new(generic::Project::new(
                        struct_type
                            .types()
                            .enumerate()
                            .map(|(idx, data_type)| {
                                ExprImpl::FunctionCall(
                                    FunctionCall::new_unchecked(
                                        ExprType::Field,
                                        vec![
                                            ExprImpl::InputRef(
                                                InputRef::new(
                                                    1,
                                                    DataType::Struct(struct_type.clone()),
                                                )
                                                .into(),
                                            ),
                                            ExprImpl::Literal(
                                                Literal::new(
                                                    Some(ScalarImpl::Int32(idx as _)),
                                                    DataType::Int32,
                                                )
                                                .into(),
                                            ),
                                        ],
                                        data_type.clone(),
                                    )
                                    .into(),
                                )
                            })
                            .collect(),
                        unnested_array,
                    ));
                    unnest_struct.into()
                };
                // Positions, within the flattened index output, of the requested columns
                // that the index stores as included info columns.
                let covered_output_col_idx = covered_table_cols_idx.iter().map(|table_col_idx| {
                    index.primary_to_included_info_column_mapping[table_col_idx]
                });
                return Ok(if non_covered_table_cols_idx.is_empty() {
                    // Fully covered: select the covered columns, then the column at
                    // `included_info_columns.len()`, which this code uses as the distance.
                    BatchProject::new(generic::Project::with_out_col_idx(
                        vector_search,
                        covered_output_col_idx.chain([index.included_info_columns.len()]),
                    ))
                    .into()
                } else {
                    // Partially covered: scan the primary table for the non-covered
                    // columns followed by its primary-key columns, and join on the pk.
                    let mut primary_table_cols_idx = Vec::with_capacity(
                        non_covered_table_cols_idx.len() + scan.table().pk().len(),
                    );
                    primary_table_cols_idx.extend(
                        non_covered_table_cols_idx
                            .iter()
                            .cloned()
                            .chain(scan.table().pk().iter().map(|order| order.column_index)),
                    );
                    let table_scan = generic::TableScan::new(
                        primary_table_cols_idx,
                        scan.table().clone(),
                        vec![],
                        vec![],
                        self.core.input.ctx(),
                        Condition::true_cond(),
                        scan.as_of(),
                    );
                    let logical_scan = LogicalScan::from(table_scan);
                    let batch_scan = logical_scan.to_batch()?;
                    let vector_search_schema = vector_search.schema();
                    let vector_search_schema_len = vector_search_schema.len();
                    // One null-safe equality per primary-key column: the pk value stored
                    // in the index info columns (left side) against the same pk column of
                    // the scan. In the join's combined column space the right side is
                    // offset by the left schema width.
                    let on_condition = Condition {
                        conjunctions: index
                            .primary_key_idx_in_info_columns
                            .iter()
                            .zip_eq_debug(0..scan.table().pk().len())
                            .map(|(pk_idx_in_info_columns, pk_idx)| {
                                let batch_scan_pk_idx = vector_search_schema.len()
                                    + non_covered_table_cols_idx.len()
                                    + pk_idx;
                                IndexSelectionRule::create_null_safe_equal_expr(
                                    *pk_idx_in_info_columns,
                                    vector_search_schema[*pk_idx_in_info_columns].data_type(),
                                    batch_scan_pk_idx,
                                    batch_scan.schema()[non_covered_table_cols_idx.len() + pk_idx]
                                        .data_type(),
                                )
                            })
                            .collect(),
                    };
                    let eq_predicate = EqJoinPredicate::create(
                        vector_search_schema.len(),
                        batch_scan.schema().len(),
                        on_condition.clone(),
                    );
                    // Join output, in `output_indices` order:
                    // - covered columns come from the left (index) side;
                    // - non-covered columns come from the right (scan) side, whose first
                    //   columns are the non-covered ones, shifted by the left width;
                    // - the last left column, treated as the distance, is appended last.
                    let join = generic::Join::new(
                        vector_search,
                        batch_scan,
                        on_condition,
                        JoinType::Inner,
                        primary_table_col_in_output
                            .iter()
                            .map(|(covered, idx)| {
                                if *covered {
                                    *idx
                                } else {
                                    *idx + vector_search_schema_len
                                }
                            })
                            .chain([vector_search_schema_len - 1])
                            .collect(),
                    );
                    let lookup_join = LogicalJoin::gen_batch_lookup_join(
                        &logical_scan,
                        eq_predicate,
                        join,
                        false,
                    )?
                    .ok_or_else(|| {
                        ErrorCode::InternalError(
                            "failed to convert to batch lookup join".to_owned(),
                        )
                    })?;
                    lookup_join.into()
                });
            }
        }
        // No usable index: compute distances row by row and keep the top-n.
        Self::to_top_n(
            self.core.input.clone(),
            self.core.left.clone(),
            self.core.right.clone(),
            self.core.distance_type,
            self.core.top_n,
            &self.core.output_indices,
        )
        .to_batch()
    }
}