1use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::array::VECTOR_DISTANCE_TYPE;
18use risingwave_common::bail;
19use risingwave_common::catalog::{Field, Schema};
20use risingwave_common::types::{DataType, ScalarImpl};
21use risingwave_common::util::column_index_mapping::ColIndexMapping;
22use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
23use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
24use risingwave_pb::catalog::vector_index_info;
25use risingwave_pb::common::PbDistanceType;
26use risingwave_pb::plan_common::JoinType;
27
28use crate::OptimizerContextRef;
29use crate::error::ErrorCode;
30use crate::expr::{
31 Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32 TableFunction, TableFunctionType, collect_input_refs,
33};
34use crate::optimizer::plan_node::batch_vector_search::BatchVectorSearchCore;
35use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
36use crate::optimizer::plan_node::generic::{
37 GenericPlanNode, GenericPlanRef, TopNLimit, ensure_sorted_required_cols,
38};
39use crate::optimizer::plan_node::utils::{Distill, childless_record};
40use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
41use crate::optimizer::property::{FunctionalDependencySet, Order};
42use crate::optimizer::rule::IndexSelectionRule;
43use crate::utils::{ColIndexMappingRewriteExt, Condition};
44
45#[derive(Debug, Clone, PartialEq, Eq, Hash)]
46struct VectorSearchCore {
47 top_n: u64,
48 distance_type: PbDistanceType,
49 left: ExprImpl,
50 right: ExprImpl,
51 output_indices: Vec<usize>,
54 input: PlanRef,
55}
56
57impl VectorSearchCore {
58 pub(crate) fn clone_with_input(&self, input: PlanRef) -> Self {
59 Self {
60 top_n: self.top_n,
61 distance_type: self.distance_type,
62 left: self.left.clone(),
63 right: self.right.clone(),
64 output_indices: self.output_indices.clone(),
65 input,
66 }
67 }
68
69 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
70 self.left = r.rewrite_expr(self.left.clone());
71 self.right = r.rewrite_expr(self.right.clone());
72 }
73
74 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
75 v.visit_expr(&self.left);
76 v.visit_expr(&self.right);
77 }
78
79 pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
80 let mut mapping = vec![None; self.input.schema().len()];
81 for (output_idx, input_idx) in self.output_indices.iter().enumerate() {
82 mapping[*input_idx] = Some(output_idx);
83 }
84 ColIndexMapping::new(mapping, self.output_indices.len() + 1)
85 }
86}
87
88impl GenericPlanNode for VectorSearchCore {
89 fn functional_dependency(&self) -> FunctionalDependencySet {
90 self.i2o_mapping()
91 .rewrite_functional_dependency_set(self.input.functional_dependency().clone())
92 }
93
94 fn schema(&self) -> Schema {
95 let fields = self
96 .output_indices
97 .iter()
98 .map(|idx| self.input.schema()[*idx].clone())
99 .chain([Field::new("vector_distance", DataType::Float64)])
100 .collect();
101 Schema { fields }
102 }
103
104 fn stream_key(&self) -> Option<Vec<usize>> {
105 self.input.stream_key().and_then(|v| {
106 let i2o_mapping = self.i2o_mapping();
107 v.iter().map(|idx| i2o_mapping.try_map(*idx)).collect()
108 })
109 }
110
111 fn ctx(&self) -> OptimizerContextRef {
112 self.input.ctx()
113 }
114}
115
116#[derive(Debug, Clone, PartialEq, Eq, Hash)]
117pub struct LogicalVectorSearch {
118 pub base: PlanBase<Logical>,
119 core: VectorSearchCore,
120}
121
122impl LogicalVectorSearch {
123 pub(crate) fn new(
124 top_n: u64,
125 distance_type: PbDistanceType,
126 left: ExprImpl,
127 right: ExprImpl,
128 output_indices: Vec<usize>,
129 input: PlanRef,
130 ) -> Self {
131 let core = VectorSearchCore {
132 top_n,
133 distance_type,
134 left,
135 right,
136 output_indices,
137 input,
138 };
139 Self::with_core(core)
140 }
141
142 fn with_core(core: VectorSearchCore) -> Self {
143 let base = PlanBase::new_logical_with_core(&core);
144 Self { base, core }
145 }
146
147 pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
148 self.core.i2o_mapping()
149 }
150}
151
152impl_plan_tree_node_for_unary! { Logical, LogicalVectorSearch }
153
154impl PlanTreeNodeUnary<Logical> for LogicalVectorSearch {
155 fn input(&self) -> PlanRef {
156 self.core.input.clone()
157 }
158
159 fn clone_with_input(&self, input: PlanRef) -> Self {
160 let core = self.core.clone_with_input(input);
161 Self::with_core(core)
162 }
163}
164
165impl Distill for LogicalVectorSearch {
166 fn distill<'a>(&self) -> XmlNode<'a> {
167 let verbose = self.base.ctx().is_explain_verbose();
168 let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
169 vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
170 vec.push(("top_n", Pretty::debug(&self.core.top_n)));
171 vec.push(("left", Pretty::debug(&self.core.left)));
172 vec.push(("right", Pretty::debug(&self.core.right)));
173
174 if verbose {
175 vec.push((
176 "output_columns",
177 Pretty::Array(
178 self.core
179 .output_indices
180 .iter()
181 .map(|input_idx| {
182 Pretty::debug(&self.core.input.schema().fields()[*input_idx])
183 })
184 .collect(),
185 ),
186 ));
187 }
188
189 childless_record("LogicalVectorSearch", vec)
190 }
191}
192
193impl ColPrunable for LogicalVectorSearch {
194 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
195 let (project_exprs, required_cols) =
196 ensure_sorted_required_cols(required_cols, self.base.schema());
197 assert!(required_cols.is_sorted());
198 let input_schema = self.core.input.schema();
199 let mut required_input_idx_bitset =
200 collect_input_refs(input_schema.len(), [&self.core.left, &self.core.right]);
201 let mut non_distance_required_input_idx = Vec::new();
202 let require_distance_col = required_cols
203 .last()
204 .map(|last_col_idx| *last_col_idx == self.core.output_indices.len())
205 .unwrap_or(false);
206 let non_distance_iter_end_idx = if require_distance_col {
207 required_cols.len() - 1
208 } else {
209 required_cols.len()
210 };
211 for &required_col_idx in &required_cols[0..non_distance_iter_end_idx] {
212 let required_input_idx = self.core.output_indices[required_col_idx];
213 non_distance_required_input_idx.push(required_input_idx);
214 required_input_idx_bitset.set(required_col_idx, true);
215 }
216 let input_required_idx = required_input_idx_bitset.ones().collect_vec();
217
218 let new_input = self.input().prune_col(&input_required_idx, ctx);
219 let mut mapping = ColIndexMapping::with_remaining_columns(
221 &input_required_idx,
222 self.input().schema().len(),
223 );
224
225 let vector_search = {
226 let mut new_core = self.core.clone_with_input(new_input);
227 new_core.left = mapping.rewrite_expr(new_core.left);
228 new_core.right = mapping.rewrite_expr(new_core.right);
229 new_core.output_indices = non_distance_required_input_idx
230 .iter()
231 .map(|input_idx| mapping.map(*input_idx))
232 .collect();
233 Self::with_core(new_core)
234 };
235 LogicalProject::create(vector_search.into(), project_exprs)
236 }
237}
238
239impl ExprRewritable<Logical> for LogicalVectorSearch {
240 fn has_rewritable_expr(&self) -> bool {
241 true
242 }
243
244 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
245 let mut core = self.core.clone();
246 core.rewrite_exprs(r);
247 Self::with_core(core).into()
248 }
249}
250
251impl ExprVisitable for LogicalVectorSearch {
252 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
253 self.core.visit_exprs(v);
254 }
255}
256
257impl PredicatePushdown for LogicalVectorSearch {
258 fn predicate_pushdown(
259 &self,
260 predicate: Condition,
261 ctx: &mut PredicatePushdownContext,
262 ) -> PlanRef {
263 gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
264 }
265}
266
267impl ToStream for LogicalVectorSearch {
268 fn logical_rewrite_for_stream(
269 &self,
270 _ctx: &mut RewriteStreamContext,
271 ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
272 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
273 }
274
275 fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
276 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
277 }
278}
279
280impl LogicalVectorSearch {
281 fn to_top_n(&self) -> LogicalTopN {
282 let (neg, expr_type) = match self.core.distance_type {
283 PbDistanceType::Unspecified => {
284 unreachable!()
285 }
286 PbDistanceType::L1 => (false, ExprType::L1Distance),
287 PbDistanceType::L2Sqr => (false, ExprType::L2Distance),
288 PbDistanceType::Cosine => (false, ExprType::CosineDistance),
289 PbDistanceType::InnerProduct => (true, ExprType::InnerProduct),
290 };
291 let mut expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
292 expr_type,
293 vec![self.core.left.clone(), self.core.right.clone()],
294 VECTOR_DISTANCE_TYPE,
295 )));
296 if neg {
297 expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
298 ExprType::Neg,
299 vec![expr],
300 VECTOR_DISTANCE_TYPE,
301 )));
302 }
303 let exprs = generic::Project::out_col_idx_exprs(
304 &self.core.input,
305 self.core.output_indices.iter().copied(),
306 )
307 .chain([expr])
308 .collect();
309
310 let input = LogicalProject::new(self.input(), exprs).into();
311 let top_n = generic::TopN::without_group(
312 input,
313 TopNLimit::Simple(self.core.top_n),
314 0,
315 Order::new(vec![ColumnOrder::new(
316 self.core.output_indices.len(),
317 OrderType::ascending(),
318 )]),
319 );
320 top_n.into()
321 }
322
323 fn as_vector_table_scan(&self) -> Option<(&LogicalScan, ExprImpl, &ExprImpl)> {
324 let scan = self.core.input.as_logical_scan()?;
325 if !scan.predicate().always_true() {
326 return None;
327 }
328 let left_const = (self.core.left.only_literal_and_func(), &self.core.left);
329 let right_const = (self.core.right.only_literal_and_func(), &self.core.right);
330 let (vector_column_expr, vector_expr) = match (left_const, right_const) {
331 ((true, _), (true, _)) => {
332 return None;
333 }
334 ((_, vector_column_expr), (true, vector_expr))
335 | ((true, vector_expr), (_, vector_column_expr)) => (vector_column_expr, vector_expr),
336 _ => return None,
337 };
338 Some((scan, vector_expr.clone(), vector_column_expr))
339 }
340
341 fn is_matched_vector_column_expr(
342 index_expr: &ExprImpl,
343 column_expr: &ExprImpl,
344 scan_output_col_idx: &[usize],
345 ) -> bool {
346 match (index_expr, column_expr) {
347 (ExprImpl::Literal(l1), ExprImpl::Literal(l2)) => l1 == l2,
348 (ExprImpl::InputRef(i1), ExprImpl::InputRef(i2)) => {
349 i1.index == scan_output_col_idx[i2.index]
350 }
351 (ExprImpl::FunctionCall(f1), ExprImpl::FunctionCall(f2)) => {
352 f1.func_type() == f2.func_type()
353 && f1.return_type() == f2.return_type()
354 && f1.inputs().len() == f2.inputs().len()
355 && f1.inputs().iter().zip_eq_fast(f2.inputs()).all(|(e1, e2)| {
356 Self::is_matched_vector_column_expr(e1, e2, scan_output_col_idx)
357 })
358 }
359 _ => false,
360 }
361 }
362}
363
364impl ToBatch for LogicalVectorSearch {
365 fn to_batch(&self) -> crate::error::Result<BatchPlanRef> {
366 if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan()
367 && !scan.vector_indexes().is_empty()
368 && self
369 .core
370 .ctx()
371 .session_ctx()
372 .config()
373 .enable_index_selection()
374 {
375 for index in scan.vector_indexes() {
376 if !Self::is_matched_vector_column_expr(
377 &index.vector_expr,
378 vector_column_expr,
379 scan.output_col_idx(),
380 ) {
381 continue;
382 }
383 if index.vector_index_info.distance_type() != self.core.distance_type {
384 continue;
385 }
386
387 let primary_table_cols_idx = self
388 .core
389 .output_indices
390 .iter()
391 .map(|input_idx| scan.output_col_idx()[*input_idx])
392 .collect_vec();
393 let mut covered_table_cols_idx = Vec::new();
394 let mut non_covered_table_cols_idx = Vec::new();
395 let mut primary_table_col_in_output =
396 Vec::with_capacity(primary_table_cols_idx.len());
397 for table_col_idx in &primary_table_cols_idx {
398 if let Some(covered_info_column_idx) = index
399 .primary_to_included_info_column_mapping
400 .get(table_col_idx)
401 {
402 covered_table_cols_idx.push(*table_col_idx);
403 primary_table_col_in_output.push((true, *covered_info_column_idx));
404 } else {
405 primary_table_col_in_output.push((false, non_covered_table_cols_idx.len()));
406 non_covered_table_cols_idx.push(*table_col_idx);
407 }
408 }
409 let vector_data_type = vector_expr.return_type();
410 let literal_vector_input = BatchValues::new(LogicalValues::new(
411 vec![vec![vector_expr]],
412 Schema::from_iter([Field::unnamed(vector_data_type)]),
413 self.core.ctx(),
414 ))
415 .into();
416 let hnsw_ef_search = match index.vector_index_info.config.as_ref().unwrap() {
417 vector_index_info::Config::Flat(_) => None,
418 vector_index_info::Config::HnswFlat(_) => Some(
419 self.core
420 .ctx()
421 .session_ctx()
422 .config()
423 .batch_hnsw_ef_search(),
424 ),
425 };
426 let core = BatchVectorSearchCore {
427 input: literal_vector_input,
428 top_n: self.core.top_n,
429 distance_type: self.core.distance_type,
430 index_name: index.index_table.name.clone(),
431 index_table_id: index.index_table.id,
432 info_column_desc: index.index_table.columns
433 [1..=index.included_info_columns.len()]
434 .iter()
435 .map(|col| col.column_desc.clone())
436 .collect(),
437 vector_column_idx: 0,
438 hnsw_ef_search,
439 ctx: self.core.ctx(),
440 };
441 let vector_search: BatchPlanRef = {
442 let vector_search: BatchPlanRef = BatchVectorSearch::with_core(core).into();
443 let unnested_array: BatchPlanRef = BatchProjectSet::new(generic::ProjectSet {
444 select_list: vec![ExprImpl::TableFunction(
445 TableFunction::new(
446 TableFunctionType::Unnest,
447 vec![ExprImpl::InputRef(
448 InputRef::new(1, vector_search.schema()[1].data_type()).into(),
449 )],
450 )?
451 .into(),
452 )],
453 input: vector_search,
454 })
455 .into();
456 let DataType::Struct(struct_type) = &unnested_array.schema()[1].data_type
457 else {
458 panic!("{:?}", unnested_array.schema()[1].data_type);
459 };
460 let unnest_struct = BatchProject::new(generic::Project::new(
461 struct_type
462 .types()
463 .enumerate()
464 .map(|(idx, data_type)| {
465 ExprImpl::FunctionCall(
466 FunctionCall::new_unchecked(
467 ExprType::Field,
468 vec![
469 ExprImpl::InputRef(
470 InputRef::new(
471 1,
472 DataType::Struct(struct_type.clone()),
473 )
474 .into(),
475 ),
476 ExprImpl::Literal(
477 Literal::new(
478 Some(ScalarImpl::Int32(idx as _)),
479 DataType::Int32,
480 )
481 .into(),
482 ),
483 ],
484 data_type.clone(),
485 )
486 .into(),
487 )
488 })
489 .collect(),
490 unnested_array,
491 ));
492 unnest_struct.into()
493 };
494 let covered_output_col_idx = covered_table_cols_idx.iter().map(|table_col_idx| {
495 index.primary_to_included_info_column_mapping[table_col_idx]
496 });
497 return Ok(if non_covered_table_cols_idx.is_empty() {
498 BatchProject::new(generic::Project::with_out_col_idx(
499 vector_search,
500 covered_output_col_idx.chain([index.included_info_columns.len()]),
501 ))
502 .into()
503 } else {
504 let mut primary_table_cols_idx = Vec::with_capacity(
505 non_covered_table_cols_idx.len() + scan.table().pk().len(),
506 );
507 primary_table_cols_idx.extend(
508 non_covered_table_cols_idx
509 .iter()
510 .cloned()
511 .chain(scan.table().pk().iter().map(|order| order.column_index)),
512 );
513 let table_scan = generic::TableScan::new(
514 primary_table_cols_idx,
515 scan.table().clone(),
516 vec![],
517 vec![],
518 self.core.input.ctx(),
519 Condition::true_cond(),
520 None,
521 );
522 let logical_scan = LogicalScan::from(table_scan);
523 let batch_scan = logical_scan.to_batch()?;
524 let vector_search_schema = vector_search.schema();
525 let vector_search_schema_len = vector_search_schema.len();
526 let on_condition = Condition {
527 conjunctions: index
528 .primary_key_idx_in_info_columns
529 .iter()
530 .zip_eq_debug(0..scan.table().pk().len())
531 .map(|(pk_idx_in_info_columns, pk_idx)| {
532 let batch_scan_pk_idx = vector_search_schema.len()
533 + non_covered_table_cols_idx.len()
534 + pk_idx;
535 IndexSelectionRule::create_null_safe_equal_expr(
536 *pk_idx_in_info_columns,
537 vector_search_schema[*pk_idx_in_info_columns].data_type(),
538 batch_scan_pk_idx,
539 batch_scan.schema()[non_covered_table_cols_idx.len() + pk_idx]
540 .data_type(),
541 )
542 })
543 .collect(),
544 };
545 let eq_predicate = EqJoinPredicate::create(
546 vector_search_schema.len(),
547 batch_scan.schema().len(),
548 on_condition.clone(),
549 );
550 let join = generic::Join::new(
551 vector_search,
552 batch_scan,
553 on_condition,
554 JoinType::Inner,
555 primary_table_col_in_output
556 .iter()
557 .map(|(covered, idx)| {
558 if *covered {
559 *idx
560 } else {
561 *idx + vector_search_schema_len
562 }
563 })
564 .chain([vector_search_schema_len - 1])
566 .collect(),
567 );
568 let lookup_join = LogicalJoin::gen_batch_lookup_join(
569 &logical_scan,
570 eq_predicate,
571 join,
572 false,
573 )?
574 .ok_or_else(|| {
575 ErrorCode::InternalError(
576 "failed to convert to batch lookup join".to_owned(),
577 )
578 })?;
579 lookup_join.into()
580 });
581 }
582 }
583 self.to_top_n().to_batch()
584 }
585}