risingwave_frontend/optimizer/plan_node/
logical_vector_search_lookup_join.rs1use std::sync::Arc;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::array::VECTOR_DISTANCE_TYPE;
19use risingwave_common::bail;
20use risingwave_common::catalog::{Field, Schema};
21use risingwave_common::types::{DataType, StructType};
22use risingwave_common::util::column_index_mapping::ColIndexMapping;
23use risingwave_pb::common::PbDistanceType;
24use risingwave_sqlparser::ast::AsOf;
25
26use crate::OptimizerContextRef;
27use crate::catalog::index_catalog::VectorIndex;
28use crate::expr::{ExprDisplay, ExprImpl};
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::plan_node::generic::{
31 GenericPlanNode, GenericPlanRef, VectorIndexLookupJoin, ensure_sorted_required_cols,
32};
33use crate::optimizer::plan_node::utils::{Distill, childless_record};
34use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *};
35use crate::optimizer::property::FunctionalDependencySet;
36use crate::utils::Condition;
37
/// Core state shared by [`LogicalVectorSearchLookupJoin`].
///
/// The output schema is every column of `input`, followed by a single list column named
/// `"array"`. Each element of that list is a struct built from the `lookup` columns selected
/// by `lookup_output_indices`, plus a distance field when `include_distance` is set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VectorSearchLookupJoinCore {
    /// Forwarded as `top_n` of the index lookup; presumably the number of nearest lookup rows
    /// gathered per input row.
    top_n: u64,
    /// Distance metric used to compare the input vector with `lookup_vector`.
    distance_type: PbDistanceType,

    /// Left child; all of its columns are passed through to the output.
    input: PlanRef,
    /// Position, in `input`'s schema, of the column holding the query vector.
    input_vector_col_idx: usize,
    /// Right child that is searched. It is only convertible to a physical plan when it is a
    /// logical scan that a vector index fully covers (see `as_index_lookup`).
    lookup: PlanRef,
    /// Expression over `lookup`'s schema producing the vector compared against the input vector.
    lookup_vector: ExprImpl,

    /// Positions, in `lookup`'s schema, of the columns packed into each struct of the output array.
    lookup_output_indices: Vec<usize>,
    /// Whether each struct in the output array carries an extra distance field.
    include_distance: bool,
}
53
54impl VectorSearchLookupJoinCore {
55 pub(crate) fn clone_with_input(&self, input: PlanRef, lookup: PlanRef) -> Self {
56 Self {
57 top_n: self.top_n,
58 distance_type: self.distance_type,
59 input,
60 input_vector_col_idx: self.input_vector_col_idx,
61 lookup,
62 lookup_vector: self.lookup_vector.clone(),
63 lookup_output_indices: self.lookup_output_indices.clone(),
64 include_distance: self.include_distance,
65 }
66 }
67
68 fn struct_type(&self) -> StructType {
69 StructType::row_expr_type(
70 self.lookup_output_indices
71 .iter()
72 .map(|i| {
73 let field = &self.lookup.schema().fields[*i];
74 field.data_type.clone()
75 })
76 .chain(self.include_distance.then_some(VECTOR_DISTANCE_TYPE)),
77 )
78 }
79}
80
81impl GenericPlanNode for VectorSearchLookupJoinCore {
82 fn functional_dependency(&self) -> FunctionalDependencySet {
83 FunctionalDependencySet::new(self.input.schema().len() + 1)
85 }
86
87 fn schema(&self) -> Schema {
88 let fields = self
89 .input
90 .schema()
91 .fields
92 .iter()
93 .cloned()
94 .chain([Field::new(
95 "array",
96 DataType::Struct(self.struct_type()).list(),
97 )])
98 .collect();
99
100 Schema { fields }
101 }
102
103 fn stream_key(&self) -> Option<Vec<usize>> {
104 self.input.stream_key().map(|key| key.to_vec())
105 }
106
107 fn ctx(&self) -> OptimizerContextRef {
108 self.input.ctx()
109 }
110}
111
/// Logical plan node that appends an `"array"` column, produced by a vector search against
/// `lookup`, to the columns of `input`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalVectorSearchLookupJoin {
    /// Plan properties (schema, stream key, ...) derived from `core`.
    pub base: PlanBase<Logical>,
    core: VectorSearchLookupJoinCore,
}
117
118impl LogicalVectorSearchLookupJoin {
119 pub(crate) fn new(
120 top_n: u64,
121 distance_type: PbDistanceType,
122 input: PlanRef,
123 input_vector_col_idx: usize,
124 lookup: PlanRef,
125 lookup_vector: ExprImpl,
126 lookup_output_indices: Vec<usize>,
127 include_distance: bool,
128 ) -> Self {
129 let core = VectorSearchLookupJoinCore {
130 top_n,
131 distance_type,
132 input,
133 input_vector_col_idx,
134 lookup,
135 lookup_vector,
136 lookup_output_indices,
137 include_distance,
138 };
139 Self::with_core(core)
140 }
141
142 fn with_core(core: VectorSearchLookupJoinCore) -> Self {
143 let base = PlanBase::new_logical_with_core(&core);
144 Self { base, core }
145 }
146}
147
148impl_plan_tree_node_for_binary! { Logical, LogicalVectorSearchLookupJoin }
149
150impl PlanTreeNodeBinary<Logical> for LogicalVectorSearchLookupJoin {
151 fn left(&self) -> PlanRef {
152 self.core.input.clone()
153 }
154
155 fn right(&self) -> PlanRef {
156 self.core.lookup.clone()
157 }
158
159 fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
160 let core = self.core.clone_with_input(left, right);
161 Self::with_core(core)
162 }
163}
164
165impl Distill for LogicalVectorSearchLookupJoin {
166 fn distill<'a>(&self) -> XmlNode<'a> {
167 let verbose = self.base.ctx().is_explain_verbose();
168 let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 });
169 vec.push(("distance_type", Pretty::debug(&self.core.distance_type)));
170 vec.push(("top_n", Pretty::debug(&self.core.top_n)));
171 vec.push((
172 "input_vector",
173 Pretty::debug(&self.core.input.schema()[self.core.input_vector_col_idx]),
174 ));
175
176 vec.push((
177 "lookup_vector",
178 Pretty::debug(&ExprDisplay {
179 expr: &self.core.lookup_vector,
180 input_schema: self.core.lookup.schema(),
181 }),
182 ));
183
184 if verbose {
185 vec.push((
186 "lookup_output_columns",
187 Pretty::Array(
188 self.core
189 .lookup_output_indices
190 .iter()
191 .map(|input_idx| {
192 Pretty::debug(&self.core.lookup.schema().fields()[*input_idx])
193 })
194 .collect(),
195 ),
196 ));
197 vec.push((
198 "include_distance",
199 Pretty::debug(&self.core.include_distance),
200 ));
201 }
202
203 childless_record("LogicalVectorSearchLookupJoin", vec)
204 }
205}
206
impl ColPrunable for LogicalVectorSearchLookupJoin {
    /// Prunes the columns of this node.
    ///
    /// If the trailing `"array"` column (index `input.schema().len()`) is not required, the
    /// join is dropped entirely and only `input` is pruned. Otherwise `input` is pruned to the
    /// required input columns. If the caller did not ask for the query-vector column, it is
    /// still kept for the search and then projected away again.
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        // `required_cols` comes back sorted. `project_exprs` presumably maps the sorted
        // columns back to the order the caller requested.
        let (project_exprs, mut required_cols) =
            ensure_sorted_required_cols(required_cols, self.base.schema());
        assert!(required_cols.is_sorted());
        // The "array" column is the last output column, so in a sorted list it can only
        // appear at the end.
        if let Some(last_col) = required_cols.last()
            && *last_col == self.core.input.schema().len()
        {
            // From here on, `required_cols` only refers to columns of `input`.
            required_cols.pop();
            let output_vector = required_cols.contains(&self.core.input_vector_col_idx);
            if !output_vector {
                // The search still needs the query vector, so request it from `input` as the
                // last column. This leaves `required_cols` unsorted.
                required_cols.push(self.core.input_vector_col_idx);
            }

            let new_input = self.core.input.prune_col(&required_cols, ctx);
            let mut core = self
                .core
                .clone_with_input(new_input, self.core.lookup.clone());

            // Re-point the vector column at its position within the pruned input.
            core.input_vector_col_idx = ColIndexMapping::with_remaining_columns(
                &required_cols,
                self.core.input.schema().len(),
            )
            .map(self.core.input_vector_col_idx);
            let vector_search = Self::with_core(core).into();
            let input = if output_vector {
                vector_search
            } else {
                // Assuming `prune_col` emits columns in `required_cols` order: drop the helper
                // vector column at `required_cols.len() - 1`, and keep the other input columns
                // plus the "array" column at `required_cols.len()`.
                LogicalProject::with_out_col_idx(
                    vector_search,
                    (0..required_cols.len() - 1).chain([required_cols.len()]),
                )
                .into()
            };

            LogicalProject::create(input, project_exprs)
        } else {
            // The "array" column is not needed, so the lookup join is removed and only
            // `input` remains.
            // NOTE(review): this relies on the join emitting exactly one row per input row;
            // confirm.
            let input = self.core.input.prune_col(&required_cols, ctx);
            LogicalProject::create(input, project_exprs)
        }
    }
}
253
// Both impls rely entirely on the traits' default methods; the `lookup_vector` expression
// held by this node is not handed to them here.
// NOTE(review): confirm that expression rewriters and visitors are meant to skip
// `lookup_vector`.
impl ExprRewritable<Logical> for LogicalVectorSearchLookupJoin {}

impl ExprVisitable for LogicalVectorSearchLookupJoin {}
257
258impl PredicatePushdown for LogicalVectorSearchLookupJoin {
259 fn predicate_pushdown(
260 &self,
261 predicate: Condition,
262 ctx: &mut PredicatePushdownContext,
263 ) -> PlanRef {
264 let input = self
266 .core
267 .input
268 .predicate_pushdown(Condition::true_cond(), ctx);
269 let lookup = self
270 .core
271 .lookup
272 .predicate_pushdown(Condition::true_cond(), ctx);
273 let core = self.core.clone_with_input(input, lookup);
274 LogicalFilter::create(Self::with_core(core).into(), predicate)
275 }
276}
277
278impl ToStream for LogicalVectorSearchLookupJoin {
279 fn logical_rewrite_for_stream(
280 &self,
281 ctx: &mut RewriteStreamContext,
282 ) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
283 if !self
284 .core
285 .input
286 .logical_rewrite_for_stream(ctx)?
287 .1
288 .is_identity()
289 {
290 bail!(
292 "LogicalVectorSearchLookupJoin does not support input that can possibly be rewritten"
293 )
294 }
295 Ok((
296 self.clone().into(),
297 ColIndexMapping::identity(self.base.schema().len()),
298 ))
299 }
300
301 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
302 if let Some(core) = self.to_vector_index_lookup_join(|plan| plan.to_stream(ctx))? {
303 if !matches!(&core.as_of, Some(AsOf::ProcessTime)) {
304 bail!("streaming vector index lookup join must be proctime temporal join");
305 }
306 return Ok(StreamVectorIndexLookupJoin::new(core)?.into());
307 }
308 bail!("LogicalVectorSearchLookupJoin should use proper vector index in streaming job")
309 }
310}
311
312impl LogicalVectorSearchLookupJoin {
313 pub(crate) fn as_index_lookup(&self) -> Option<(&Arc<VectorIndex>, Vec<usize>, Option<AsOf>)> {
314 if let Some(scan) = self.core.lookup.as_logical_scan()
315 && let Some((
316 index,
317 _covered_table_cols_idx,
318 non_covered_table_cols_idx,
319 primary_table_col_in_output,
320 )) = LogicalVectorSearch::resolve_vector_index_lookup(
321 scan,
322 &self.core.lookup_vector,
323 self.core.distance_type,
324 &self.core.lookup_output_indices,
325 )
326 && non_covered_table_cols_idx.is_empty()
327 {
328 let info_output_indices = primary_table_col_in_output
329 .iter()
330 .map(|(covered, idx_in_index_info_columns)| {
331 assert!(*covered);
332 *idx_in_index_info_columns
333 })
334 .collect();
335 Some((index, info_output_indices, scan.as_of()))
336 } else {
337 None
338 }
339 }
340}
341
342impl LogicalVectorSearchLookupJoin {
343 fn to_vector_index_lookup_join<PlanRef>(
344 &self,
345 gen_input: impl FnOnce(&LogicalPlanRef) -> Result<PlanRef>,
346 ) -> Result<Option<VectorIndexLookupJoin<PlanRef>>> {
347 if let Some((index, info_output_indices, as_of)) = self.as_index_lookup() {
348 let hnsw_ef_search =
349 index.resolve_hnsw_ef_search(&self.core.ctx().session_ctx().config());
350 let core = VectorIndexLookupJoin {
351 input: gen_input(&self.core.input)?,
352 top_n: self.core.top_n,
353 distance_type: self.core.distance_type,
354 index_name: index.index_table.name.clone(),
355 index_table_id: index.index_table.id,
356 info_column_desc: index.info_column_desc(),
357 info_output_indices,
358 include_distance: self.core.include_distance,
359 as_of,
360 vector_column_idx: self.core.input_vector_col_idx,
361 hnsw_ef_search,
362 ctx: self.core.ctx(),
363 };
364 return Ok(Some(core));
365 }
366 Ok(None)
367 }
368}
369
370impl ToBatch for LogicalVectorSearchLookupJoin {
371 fn to_batch(&self) -> Result<BatchPlanRef> {
372 if let Some(core) = self.to_vector_index_lookup_join(|plan| plan.to_batch())? {
373 return Ok(BatchVectorSearch::with_core(core).into());
374 }
375
376 bail!("no index found for BatchVectorSearchLookupJoin")
377 }
378}