risingwave_frontend/optimizer/plan_node/
logical_vector_search_lookup_join.rs1use std::sync::Arc;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::array::VECTOR_DISTANCE_TYPE;
19use risingwave_common::bail;
20use risingwave_common::catalog::{Field, Schema};
21use risingwave_common::types::{DataType, StructType};
22use risingwave_common::util::column_index_mapping::ColIndexMapping;
23use risingwave_pb::common::PbDistanceType;
24use risingwave_sqlparser::ast::AsOf;
25
26use crate::OptimizerContextRef;
27use crate::catalog::index_catalog::VectorIndex;
28use crate::expr::{ExprDisplay, ExprImpl};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::{
31 GenericPlanNode, GenericPlanRef, VectorIndexLookupJoin, ensure_sorted_required_cols,
32};
33use crate::optimizer::plan_node::utils::{Distill, childless_record};
34use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
35use crate::optimizer::property::FunctionalDependencySet;
36use crate::utils::Condition;
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39struct VectorSearchLookupJoinCore {
40 top_n: u64,
41 distance_type: PbDistanceType,
42
43 input: PlanRef,
44 input_vector_col_idx: usize,
45 lookup: PlanRef,
46 lookup_vector: ExprImpl,
47
48 lookup_output_indices: Vec<usize>,
51 include_distance: bool,
52}
53
54impl VectorSearchLookupJoinCore {
55 pub(crate) fn clone_with_input(&self, input: PlanRef, lookup: PlanRef) -> Self {
56 Self {
57 top_n: self.top_n,
58 distance_type: self.distance_type,
59 input,
60 input_vector_col_idx: self.input_vector_col_idx,
61 lookup,
62 lookup_vector: self.lookup_vector.clone(),
63 lookup_output_indices: self.lookup_output_indices.clone(),
64 include_distance: self.include_distance,
65 }
66 }
67
68 fn struct_type(&self) -> StructType {
69 StructType::new(
70 self.lookup_output_indices
71 .iter()
72 .map(|i| {
73 let field = &self.lookup.schema().fields[*i];
74 (field.name.clone(), field.data_type.clone())
75 })
76 .chain(
77 self.include_distance
78 .then(|| ("vector_distance".to_owned(), VECTOR_DISTANCE_TYPE)),
79 ),
80 )
81 }
82}
83
84impl GenericPlanNode for VectorSearchLookupJoinCore {
85 fn functional_dependency(&self) -> FunctionalDependencySet {
86 FunctionalDependencySet::new(self.input.schema().len() + 1)
88 }
89
90 fn schema(&self) -> Schema {
91 let fields = self
92 .input
93 .schema()
94 .fields
95 .iter()
96 .cloned()
97 .chain([Field::new(
98 "array",
99 DataType::Struct(self.struct_type()).list(),
100 )])
101 .collect();
102
103 Schema { fields }
104 }
105
106 fn stream_key(&self) -> Option<Vec<usize>> {
107 self.input.stream_key().map(|key| key.to_vec())
108 }
109
110 fn ctx(&self) -> OptimizerContextRef {
111 self.input.ctx()
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct LogicalVectorSearchLookupJoin {
117 pub base: PlanBase<Logical>,
118 core: VectorSearchLookupJoinCore,
119}
120
121impl LogicalVectorSearchLookupJoin {
122 pub(crate) fn new(
123 top_n: u64,
124 distance_type: PbDistanceType,
125 input: PlanRef,
126 input_vector_col_idx: usize,
127 lookup: PlanRef,
128 lookup_vector: ExprImpl,
129 lookup_output_indices: Vec<usize>,
130 include_distance: bool,
131 ) -> Self {
132 let core = VectorSearchLookupJoinCore {
133 top_n,
134 distance_type,
135 input,
136 input_vector_col_idx,
137 lookup,
138 lookup_vector,
139 lookup_output_indices,
140 include_distance,
141 };
142 Self::with_core(core)
143 }
144
145 fn with_core(core: VectorSearchLookupJoinCore) -> Self {
146 let base = PlanBase::new_logical_with_core(&core);
147 Self { base, core }
148 }
149}
150
151impl_plan_tree_node_for_binary! { Logical, LogicalVectorSearchLookupJoin }
152
153impl PlanTreeNodeBinary<Logical> for LogicalVectorSearchLookupJoin {
154 fn left(&self) -> PlanRef {
155 self.core.input.clone()
156 }
157
158 fn right(&self) -> PlanRef {
159 self.core.lookup.clone()
160 }
161
162 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
163 let core = self.core.clone_with_input(left, right);
164 Self::with_core(core)
165 }
166}
167
168impl Distill for LogicalVectorSearchLookupJoin {
169 fn distill<'a>(&self) -> XmlNode<'a> {
170 let verbose = self.base.ctx().is_explain_verbose();
171 let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
172 vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
173 vec.push(("top_n", Pretty::debug(&self.core.top_n)));
174 vec.push((
175 "input_vector",
176 Pretty::debug(&self.core.input.schema()[self.core.input_vector_col_idx]),
177 ));
178
179 vec.push((
180 "lookup_vector",
181 Pretty::debug(&ExprDisplay {
182 expr: &self.core.lookup_vector,
183 input_schema: self.core.lookup.schema(),
184 }),
185 ));
186
187 if verbose {
188 vec.push((
189 "lookup_output_columns",
190 Pretty::Array(
191 self.core
192 .lookup_output_indices
193 .iter()
194 .map(|input_idx| {
195 Pretty::debug(&self.core.lookup.schema().fields()[*input_idx])
196 })
197 .collect(),
198 ),
199 ));
200 vec.push((
201 "include_distance",
202 Pretty::debug(&self.core.include_distance),
203 ));
204 }
205
206 childless_record("LogicalVectorSearchLookupJoin", vec)
207 }
208}
209
210impl ColPrunable for LogicalVectorSearchLookupJoin {
211 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
212 let (project_exprs, mut required_cols) =
213 ensure_sorted_required_cols(required_cols, self.base.schema());
214 assert!(required_cols.is_sorted());
215 if let Some(last_col) = required_cols.last()
216 && *last_col == self.core.input.schema().len()
217 {
218 required_cols.pop();
220 let output_vector = required_cols.contains(&self.core.input_vector_col_idx);
221 if !output_vector {
222 required_cols.push(self.core.input_vector_col_idx);
224 }
225
226 let new_input = self.core.input.prune_col(&required_cols, ctx);
227 let mut core = self
228 .core
229 .clone_with_input(new_input, self.core.lookup.clone());
230
231 core.input_vector_col_idx = ColIndexMapping::with_remaining_columns(
232 &required_cols,
233 self.core.input.schema().len(),
234 )
235 .map(self.core.input_vector_col_idx);
236 let vector_search = Self::with_core(core).into();
237 let input = if output_vector {
238 vector_search
239 } else {
240 LogicalProject::with_out_col_idx(
242 vector_search,
243 (0..required_cols.len() - 1).chain([required_cols.len()]),
244 )
245 .into()
246 };
247
248 LogicalProject::create(input, project_exprs)
249 } else {
250 let input = self.core.input.prune_col(&required_cols, ctx);
252 LogicalProject::create(input, project_exprs)
253 }
254 }
255}
256
257impl ExprRewritable<Logical> for LogicalVectorSearchLookupJoin {}
258
259impl ExprVisitable for LogicalVectorSearchLookupJoin {}
260
261impl PredicatePushdown for LogicalVectorSearchLookupJoin {
262 fn predicate_pushdown(
263 &self,
264 predicate: Condition,
265 ctx: &mut PredicatePushdownContext,
266 ) -> PlanRef {
267 let input = self
269 .core
270 .input
271 .predicate_pushdown(Condition::true_cond(), ctx);
272 let lookup = self
273 .core
274 .lookup
275 .predicate_pushdown(Condition::true_cond(), ctx);
276 let core = self.core.clone_with_input(input, lookup);
277 LogicalFilter::create(Self::with_core(core).into(), predicate)
278 }
279}
280
281impl ToStream for LogicalVectorSearchLookupJoin {
282 fn logical_rewrite_for_stream(
283 &self,
284 _ctx: &mut RewriteStreamContext,
285 ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
286 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
287 }
288
289 fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
290 bail!("LogicalVectorSearch can only for batch plan, not stream plan");
291 }
292}
293
294impl LogicalVectorSearchLookupJoin {
295 pub(crate) fn as_index_lookup(&self) -> Option<(&Arc<VectorIndex>, Vec<usize>, Option<AsOf>)> {
296 if let Some(scan) = self.core.lookup.as_logical_scan()
297 && let Some((
298 index,
299 _covered_table_cols_idx,
300 non_covered_table_cols_idx,
301 primary_table_col_in_output,
302 )) = LogicalVectorSearch::resolve_vector_index_lookup(
303 scan,
304 &self.core.lookup_vector,
305 self.core.distance_type,
306 &self.core.lookup_output_indices,
307 )
308 && non_covered_table_cols_idx.is_empty()
309 {
310 let info_output_indices = primary_table_col_in_output
311 .iter()
312 .map(|(covered, idx_in_index_info_columns)| {
313 assert!(*covered);
314 *idx_in_index_info_columns
315 })
316 .collect();
317 Some((index, info_output_indices, scan.as_of()))
318 } else {
319 None
320 }
321 }
322}
323
324impl ToBatch for LogicalVectorSearchLookupJoin {
325 fn to_batch(&self) -> Result<BatchPlanRef> {
326 if let Some((index, info_output_indices, as_of)) = self.as_index_lookup() {
327 let hnsw_ef_search =
328 index.resolve_hnsw_ef_search(&self.core.ctx().session_ctx().config());
329 let core = VectorIndexLookupJoin {
330 input: self.core.input.to_batch()?,
331 top_n: self.core.top_n,
332 distance_type: self.core.distance_type,
333 index_name: index.index_table.name.clone(),
334 index_table_id: index.index_table.id,
335 info_column_desc: index.info_column_desc(),
336 info_output_indices,
337 include_distance: self.core.include_distance,
338 as_of,
339 vector_column_idx: self.core.input_vector_col_idx,
340 hnsw_ef_search,
341 ctx: self.core.ctx(),
342 };
343 return Ok(BatchVectorSearch::with_core(core).into());
344 }
345
346 bail!("no index found for BatchVectorSearchLookupJoin")
347 }
348}