1use itertools::Itertools;
16use pretty_xmlish::{Pretty, XmlNode};
17use risingwave_common::array::VECTOR_DISTANCE_TYPE;
18use risingwave_common::bail;
19use risingwave_common::catalog::{Field, Schema};
20use risingwave_common::types::{DataType, ScalarImpl};
21use risingwave_common::util::column_index_mapping::ColIndexMapping;
22use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
23use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
24use risingwave_pb::catalog::vector_index_info;
25use risingwave_pb::common::PbDistanceType;
26use risingwave_pb::plan_common::JoinType;
27
28use crate::OptimizerContextRef;
29use crate::error::ErrorCode;
30use crate::expr::{
31 Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32 TableFunction, TableFunctionType, collect_input_refs,
33};
34use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
35use crate::optimizer::plan_node::generic::{
36 GenericPlanNode, GenericPlanRef, TopNLimit, VectorIndexLookupJoin, ensure_sorted_required_cols,
37};
38use crate::optimizer::plan_node::utils::{Distill, childless_record};
39use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
40use crate::optimizer::property::{FunctionalDependencySet, Order};
41use crate::optimizer::rule::IndexSelectionRule;
42use crate::utils::{ColIndexMappingRewriteExt, Condition};
43
44#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45struct VectorSearchCore {
46 top_n: u64,
47 distance_type: PbDistanceType,
48 left: ExprImpl,
49 right: ExprImpl,
50 output_indices: Vec<usize>,
53 input: PlanRef,
54}
55
56impl VectorSearchCore {
57 pub(crate) fn clone_with_input(&self, input: PlanRef) -> Self {
58 Self {
59 top_n: self.top_n,
60 distance_type: self.distance_type,
61 left: self.left.clone(),
62 right: self.right.clone(),
63 output_indices: self.output_indices.clone(),
64 input,
65 }
66 }
67
68 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
69 self.left = r.rewrite_expr(self.left.clone());
70 self.right = r.rewrite_expr(self.right.clone());
71 }
72
73 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
74 v.visit_expr(&self.left);
75 v.visit_expr(&self.right);
76 }
77
78 pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
79 let mut mapping = vec![None; self.input.schema().len()];
80 for (output_idx, input_idx) in self.output_indices.iter().enumerate() {
81 mapping[*input_idx] = Some(output_idx);
82 }
83 ColIndexMapping::new(mapping, self.output_indices.len() + 1)
84 }
85}
86
87impl GenericPlanNode for VectorSearchCore {
88 fn functional_dependency(&self) -> FunctionalDependencySet {
89 self.i2o_mapping()
90 .rewrite_functional_dependency_set(self.input.functional_dependency().clone())
91 }
92
93 fn schema(&self) -> Schema {
94 let fields = self
95 .output_indices
96 .iter()
97 .map(|idx| self.input.schema()[*idx].clone())
98 .chain([Field::new("vector_distance", DataType::Float64)])
99 .collect();
100 Schema { fields }
101 }
102
103 fn stream_key(&self) -> Option<Vec<usize>> {
104 self.input.stream_key().and_then(|v| {
105 let i2o_mapping = self.i2o_mapping();
106 v.iter().map(|idx| i2o_mapping.try_map(*idx)).collect()
107 })
108 }
109
110 fn ctx(&self) -> OptimizerContextRef {
111 self.input.ctx()
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct LogicalVectorSearch {
117 pub base: PlanBase<Logical>,
118 core: VectorSearchCore,
119}
120
121impl LogicalVectorSearch {
122 pub(crate) fn new(
123 top_n: u64,
124 distance_type: PbDistanceType,
125 left: ExprImpl,
126 right: ExprImpl,
127 output_indices: Vec<usize>,
128 input: PlanRef,
129 ) -> Self {
130 let core = VectorSearchCore {
131 top_n,
132 distance_type,
133 left,
134 right,
135 output_indices,
136 input,
137 };
138 Self::with_core(core)
139 }
140
141 fn with_core(core: VectorSearchCore) -> Self {
142 let base = PlanBase::new_logical_with_core(&core);
143 Self { base, core }
144 }
145
146 pub(crate) fn i2o_mapping(&self) -> ColIndexMapping {
147 self.core.i2o_mapping()
148 }
149}
150
151impl_plan_tree_node_for_unary! { Logical, LogicalVectorSearch }
152
153impl PlanTreeNodeUnary<Logical> for LogicalVectorSearch {
154 fn input(&self) -> PlanRef {
155 self.core.input.clone()
156 }
157
158 fn clone_with_input(&self, input: PlanRef) -> Self {
159 let core = self.core.clone_with_input(input);
160 Self::with_core(core)
161 }
162}
163
164impl Distill for LogicalVectorSearch {
165 fn distill<'a>(&self) -> XmlNode<'a> {
166 let verbose = self.base.ctx().is_explain_verbose();
167 let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
168 vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
169 vec.push(("top_n", Pretty::debug(&self.core.top_n)));
170 vec.push(("left", Pretty::debug(&self.core.left)));
171 vec.push(("right", Pretty::debug(&self.core.right)));
172
173 if verbose {
174 vec.push((
175 "output_columns",
176 Pretty::Array(
177 self.core
178 .output_indices
179 .iter()
180 .map(|input_idx| {
181 Pretty::debug(&self.core.input.schema().fields()[*input_idx])
182 })
183 .collect(),
184 ),
185 ));
186 }
187
188 childless_record("LogicalVectorSearch", vec)
189 }
190}
191
192impl ColPrunable for LogicalVectorSearch {
193 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
194 let (project_exprs, required_cols) =
195 ensure_sorted_required_cols(required_cols, self.base.schema());
196 assert!(required_cols.is_sorted());
197 let input_schema = self.core.input.schema();
198 let mut required_input_idx_bitset =
199 collect_input_refs(input_schema.len(), [&self.core.left, &self.core.right]);
200 let mut non_distance_required_input_idx = Vec::new();
201 let require_distance_col = required_cols
202 .last()
203 .map(|last_col_idx| *last_col_idx == self.core.output_indices.len())
204 .unwrap_or(false);
205 let non_distance_iter_end_idx = if require_distance_col {
206 required_cols.len() - 1
207 } else {
208 required_cols.len()
209 };
210 for &required_col_idx in &required_cols[0..non_distance_iter_end_idx] {
211 let required_input_idx = self.core.output_indices[required_col_idx];
212 non_distance_required_input_idx.push(required_input_idx);
213 required_input_idx_bitset.set(required_col_idx, true);
214 }
215 let input_required_idx = required_input_idx_bitset.ones().collect_vec();
216
217 let new_input = self.input().prune_col(&input_required_idx, ctx);
218 let mut mapping = ColIndexMapping::with_remaining_columns(
220 &input_required_idx,
221 self.input().schema().len(),
222 );
223
224 let vector_search = {
225 let mut new_core = self.core.clone_with_input(new_input);
226 new_core.left = mapping.rewrite_expr(new_core.left);
227 new_core.right = mapping.rewrite_expr(new_core.right);
228 new_core.output_indices = non_distance_required_input_idx
229 .iter()
230 .map(|input_idx| mapping.map(*input_idx))
231 .collect();
232 Self::with_core(new_core)
233 };
234 LogicalProject::create(vector_search.into(), project_exprs)
235 }
236}
237
238impl ExprRewritable<Logical> for LogicalVectorSearch {
239 fn has_rewritable_expr(&self) -> bool {
240 true
241 }
242
243 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
244 let mut core = self.core.clone();
245 core.rewrite_exprs(r);
246 Self::with_core(core).into()
247 }
248}
249
250impl ExprVisitable for LogicalVectorSearch {
251 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
252 self.core.visit_exprs(v);
253 }
254}
255
256impl PredicatePushdown for LogicalVectorSearch {
257 fn predicate_pushdown(
258 &self,
259 predicate: Condition,
260 ctx: &mut PredicatePushdownContext,
261 ) -> PlanRef {
262 gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
263 }
264}
265
266impl ToStream for LogicalVectorSearch {
267 fn logical_rewrite_for_stream(
268 &self,
269 _ctx: &mut RewriteStreamContext,
270 ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
271 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
272 }
273
274 fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
275 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
276 }
277}
278
279impl LogicalVectorSearch {
280 fn to_top_n(&self) -> LogicalTopN {
281 let (neg, expr_type) = match self.core.distance_type {
282 PbDistanceType::Unspecified => {
283 unreachable!()
284 }
285 PbDistanceType::L1 => (false, ExprType::L1Distance),
286 PbDistanceType::L2Sqr => (false, ExprType::L2Distance),
287 PbDistanceType::Cosine => (false, ExprType::CosineDistance),
288 PbDistanceType::InnerProduct => (true, ExprType::InnerProduct),
289 };
290 let mut expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
291 expr_type,
292 vec![self.core.left.clone(), self.core.right.clone()],
293 VECTOR_DISTANCE_TYPE,
294 )));
295 if neg {
296 expr = ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
297 ExprType::Neg,
298 vec![expr],
299 VECTOR_DISTANCE_TYPE,
300 )));
301 }
302 let exprs = generic::Project::out_col_idx_exprs(
303 &self.core.input,
304 self.core.output_indices.iter().copied(),
305 )
306 .chain([expr])
307 .collect();
308
309 let input = LogicalProject::new(self.input(), exprs).into();
310 let top_n = generic::TopN::without_group(
311 input,
312 TopNLimit::Simple(self.core.top_n),
313 0,
314 Order::new(vec![ColumnOrder::new(
315 self.core.output_indices.len(),
316 OrderType::ascending(),
317 )]),
318 );
319 top_n.into()
320 }
321
322 fn as_vector_table_scan(&self) -> Option<(&LogicalScan, ExprImpl, &ExprImpl)> {
323 let scan = self.core.input.as_logical_scan()?;
324 if !scan.predicate().always_true() {
325 return None;
326 }
327 let left_const = (self.core.left.only_literal_and_func(), &self.core.left);
328 let right_const = (self.core.right.only_literal_and_func(), &self.core.right);
329 let (vector_column_expr, vector_expr) = match (left_const, right_const) {
330 ((true, _), (true, _)) => {
331 return None;
332 }
333 ((_, vector_column_expr), (true, vector_expr))
334 | ((true, vector_expr), (_, vector_column_expr)) => (vector_column_expr, vector_expr),
335 _ => return None,
336 };
337 Some((scan, vector_expr.clone(), vector_column_expr))
338 }
339
340 fn is_matched_vector_column_expr(
341 index_expr: &ExprImpl,
342 column_expr: &ExprImpl,
343 scan_output_col_idx: &[usize],
344 ) -> bool {
345 match (index_expr, column_expr) {
346 (ExprImpl::Literal(l1), ExprImpl::Literal(l2)) => l1 == l2,
347 (ExprImpl::InputRef(i1), ExprImpl::InputRef(i2)) => {
348 i1.index == scan_output_col_idx[i2.index]
349 }
350 (ExprImpl::FunctionCall(f1), ExprImpl::FunctionCall(f2)) => {
351 f1.func_type() == f2.func_type()
352 && f1.return_type() == f2.return_type()
353 && f1.inputs().len() == f2.inputs().len()
354 && f1.inputs().iter().zip_eq_fast(f2.inputs()).all(|(e1, e2)| {
355 Self::is_matched_vector_column_expr(e1, e2, scan_output_col_idx)
356 })
357 }
358 _ => false,
359 }
360 }
361}
362
impl ToBatch for LogicalVectorSearch {
    /// Converts this node into a batch plan.
    ///
    /// When the input is a predicate-free table scan with a usable vector
    /// index (matching vector expression and distance type) and index
    /// selection is enabled in the session config, the plan is built on a
    /// vector index lookup. If the index does not cover every requested
    /// column, the lookup result is joined back to the primary table on its
    /// primary key. Otherwise the node falls back to the projection + top-n
    /// plan from `to_top_n`.
    fn to_batch(&self) -> crate::error::Result<BatchPlanRef> {
        if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan()
            && !scan.vector_indexes().is_empty()
            && self
                .core
                .ctx()
                .session_ctx()
                .config()
                .enable_index_selection()
        {
            // Every path through the loop body below either `continue`s before
            // `vector_expr` is consumed or returns, so the first matching index wins.
            for index in scan.vector_indexes() {
                // Skip indexes built on a different vector expression.
                if !Self::is_matched_vector_column_expr(
                    &index.vector_expr,
                    vector_column_expr,
                    scan.output_col_idx(),
                ) {
                    continue;
                }
                // Skip indexes built for a different distance type.
                if index.vector_index_info.distance_type() != self.core.distance_type {
                    continue;
                }

                // Translate the requested scan-output positions through
                // `scan.output_col_idx()` (presumably into primary-table column indices).
                let primary_table_cols_idx = self
                    .core
                    .output_indices
                    .iter()
                    .map(|input_idx| scan.output_col_idx()[*input_idx])
                    .collect_vec();
                // Split the requested columns into those found in the index's
                // `primary_to_included_info_column_mapping` and those that are not.
                let mut covered_table_cols_idx = Vec::new();
                let mut non_covered_table_cols_idx = Vec::new();
                // One entry per requested column, in output order:
                // `(true, info column index)` when covered, or
                // `(false, position within non_covered_table_cols_idx)` otherwise.
                let mut primary_table_col_in_output =
                    Vec::with_capacity(primary_table_cols_idx.len());
                for table_col_idx in &primary_table_cols_idx {
                    if let Some(covered_info_column_idx) = index
                        .primary_to_included_info_column_mapping
                        .get(table_col_idx)
                    {
                        covered_table_cols_idx.push(*table_col_idx);
                        primary_table_col_in_output.push((true, *covered_info_column_idx));
                    } else {
                        primary_table_col_in_output.push((false, non_covered_table_cols_idx.len()));
                        non_covered_table_cols_idx.push(*table_col_idx);
                    }
                }
                // A single-row, single-column values node holding the query vector
                // serves as the input of the index lookup.
                let vector_data_type = vector_expr.return_type();
                let literal_vector_input = BatchValues::new(LogicalValues::new(
                    vec![vec![vector_expr]],
                    Schema::from_iter([Field::new("query_vector", vector_data_type)]),
                    self.core.ctx(),
                ))
                .into();
                // Only the HNSW config takes an `ef_search` value, read from the session config.
                let hnsw_ef_search = match index.vector_index_info.config.as_ref().unwrap() {
                    vector_index_info::Config::Flat(_) => None,
                    vector_index_info::Config::HnswFlat(_) => Some(
                        self.core
                            .ctx()
                            .session_ctx()
                            .config()
                            .batch_hnsw_ef_search(),
                    ),
                };
                let info_column_desc = index.info_column_desc();
                let core = VectorIndexLookupJoin {
                    input: literal_vector_input,
                    top_n: self.core.top_n,
                    distance_type: self.core.distance_type,
                    index_name: index.index_table.name.clone(),
                    index_table_id: index.index_table.id,
                    // Request every info column from the index.
                    info_output_indices: (0..info_column_desc.len()).collect(),
                    info_column_desc,
                    include_distance: true,
                    as_of: scan.as_of(),
                    // The query vector is the only column of the values input.
                    vector_column_idx: 0,
                    hnsw_ef_search,
                    ctx: self.core.ctx(),
                };
                // Flatten the lookup result: unnest column 1 of the lookup output,
                // then expand each field of the resulting struct into its own column.
                let vector_search: BatchPlanRef = {
                    let vector_search: BatchPlanRef = BatchVectorSearch::with_core(core).into();
                    let unnested_array: BatchPlanRef = BatchProjectSet::new(generic::ProjectSet {
                        select_list: vec![ExprImpl::TableFunction(
                            TableFunction::new(
                                TableFunctionType::Unnest,
                                vec![ExprImpl::InputRef(
                                    InputRef::new(1, vector_search.schema()[1].data_type()).into(),
                                )],
                            )?
                            .into(),
                        )],
                        input: vector_search,
                    })
                    .into();
                    // Column 1 of the project-set output must be a struct; anything
                    // else panics here.
                    let DataType::Struct(struct_type) = &unnested_array.schema()[1].data_type
                    else {
                        panic!("{:?}", unnested_array.schema()[1].data_type);
                    };
                    // One `Field` access per struct member, in declaration order.
                    let unnest_struct = BatchProject::new(generic::Project::new(
                        struct_type
                            .types()
                            .enumerate()
                            .map(|(idx, data_type)| {
                                ExprImpl::FunctionCall(
                                    FunctionCall::new_unchecked(
                                        ExprType::Field,
                                        vec![
                                            ExprImpl::InputRef(
                                                InputRef::new(
                                                    1,
                                                    DataType::Struct(struct_type.clone()),
                                                )
                                                .into(),
                                            ),
                                            ExprImpl::Literal(
                                                Literal::new(
                                                    Some(ScalarImpl::Int32(idx as _)),
                                                    DataType::Int32,
                                                )
                                                .into(),
                                            ),
                                        ],
                                        data_type.clone(),
                                    )
                                    .into(),
                                )
                            })
                            .collect(),
                        unnested_array,
                    ));
                    unnest_struct.into()
                };
                let covered_output_col_idx = covered_table_cols_idx.iter().map(|table_col_idx| {
                    index.primary_to_included_info_column_mapping[table_col_idx]
                });
                return Ok(if non_covered_table_cols_idx.is_empty() {
                    // Everything requested is available from the index: project the
                    // covered columns plus the column at `included_info_columns.len()`
                    // (presumably the distance column; confirm against the index layout).
                    BatchProject::new(generic::Project::with_out_col_idx(
                        vector_search,
                        covered_output_col_idx.chain([index.included_info_columns.len()]),
                    ))
                    .into()
                } else {
                    // Some columns are missing from the index: scan them from the
                    // primary table together with its primary key, and join on that key.
                    let mut primary_table_cols_idx = Vec::with_capacity(
                        non_covered_table_cols_idx.len() + scan.table().pk().len(),
                    );
                    // Scan layout: non-covered columns first, primary key columns after.
                    primary_table_cols_idx.extend(
                        non_covered_table_cols_idx
                            .iter()
                            .cloned()
                            .chain(scan.table().pk().iter().map(|order| order.column_index)),
                    );
                    let table_scan = generic::TableScan::new(
                        primary_table_cols_idx,
                        scan.table().clone(),
                        vec![],
                        vec![],
                        self.core.input.ctx(),
                        Condition::true_cond(),
                        scan.as_of(),
                    );
                    let logical_scan = LogicalScan::from(table_scan);
                    let batch_scan = logical_scan.to_batch()?;
                    let vector_search_schema = vector_search.schema();
                    let vector_search_schema_len = vector_search_schema.len();
                    // One null-safe equality per primary key column, pairing its position
                    // among the index info columns with its position on the scan side of
                    // the join (offset by the vector search width and the non-covered columns).
                    let on_condition = Condition {
                        conjunctions: index
                            .primary_key_idx_in_info_columns
                            .iter()
                            .zip_eq_debug(0..scan.table().pk().len())
                            .map(|(pk_idx_in_info_columns, pk_idx)| {
                                let batch_scan_pk_idx = vector_search_schema.len()
                                    + non_covered_table_cols_idx.len()
                                    + pk_idx;
                                IndexSelectionRule::create_null_safe_equal_expr(
                                    *pk_idx_in_info_columns,
                                    vector_search_schema[*pk_idx_in_info_columns].data_type(),
                                    batch_scan_pk_idx,
                                    batch_scan.schema()[non_covered_table_cols_idx.len() + pk_idx]
                                        .data_type(),
                                )
                            })
                            .collect(),
                    };
                    let eq_predicate = EqJoinPredicate::create(
                        vector_search_schema.len(),
                        batch_scan.schema().len(),
                        on_condition.clone(),
                    );
                    // Join output: the requested columns in their original order (covered
                    // ones from the vector search side, the others from the scan side),
                    // followed by the last column of the vector search side.
                    let join = generic::Join::new(
                        vector_search,
                        batch_scan,
                        on_condition,
                        JoinType::Inner,
                        primary_table_col_in_output
                            .iter()
                            .map(|(covered, idx)| {
                                if *covered {
                                    *idx
                                } else {
                                    *idx + vector_search_schema_len
                                }
                            })
                            .chain([vector_search_schema_len - 1])
                            .collect(),
                    );
                    let lookup_join = LogicalJoin::gen_batch_lookup_join(
                        &logical_scan,
                        eq_predicate,
                        join,
                        false,
                    )?
                    .ok_or_else(|| {
                        ErrorCode::InternalError(
                            "failed to convert to batch lookup join".to_owned(),
                        )
                    })?;
                    lookup_join.into()
                });
            }
        }
        // No usable vector index: fall back to projection + top-n.
        self.to_top_n().to_batch()
    }
}