risingwave_frontend/optimizer/rule/
top_n_to_vector_search_rule.rs1use std::assert_matches::assert_matches;
16
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_pb::common::PbDistanceType;
20
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, InputRef};
22use crate::optimizer::LogicalPlanRef;
23use crate::optimizer::plan_node::generic::TopNLimit;
24use crate::optimizer::plan_node::{
25 LogicalPlanRef as PlanRef, LogicalProject, LogicalTopN, LogicalVectorSearch, PlanTreeNodeUnary,
26};
27use crate::optimizer::rule::prelude::*;
28use crate::optimizer::rule::{BoxedRule, ProjectMergeRule, Rule};
29
30pub struct TopNToVectorSearchRule;
31
32impl TopNToVectorSearchRule {
33 pub fn create() -> BoxedRule<Logical> {
34 Box::new(TopNToVectorSearchRule)
35 }
36}
37
38fn merge_consecutive_projections(input: LogicalPlanRef) -> Option<(Vec<ExprImpl>, LogicalPlanRef)> {
39 let projection = input.as_logical_project()?;
40 let mut exprs = projection.exprs().clone();
41 let mut input = projection.input();
42 while let Some(projection) = input.as_logical_project() {
43 exprs = ProjectMergeRule::merge_project_exprs(&exprs, projection.exprs(), false)?;
44 input = projection.input();
45 }
46 Some((exprs, input))
47}
48
49impl TopNToVectorSearchRule {
50 fn resolve_vector_search(top_n: &LogicalTopN) -> Option<(LogicalVectorSearch, Vec<ExprImpl>)> {
51 if !top_n.group_key().is_empty() {
52 return None;
54 }
55 if top_n.offset() > 0 {
56 return None;
57 }
58 let TopNLimit::Simple(limit) = top_n.limit_attr() else {
59 return None;
61 };
62 let [order]: &[ColumnOrder; 1] = top_n
64 .topn_order()
65 .column_orders
66 .as_slice()
67 .try_into()
68 .ok()?;
69 if order.order_type.is_descending() || order.order_type.nulls_are_smallest() {
70 return None;
72 }
73
74 let (exprs, projection_input) = merge_consecutive_projections(top_n.input())?;
76
77 let order_expr = &exprs[order.column_index];
78 let ExprImpl::FunctionCall(call) = order_expr else {
79 return None;
80 };
81 let (call, distance_type) = match call.func_type() {
82 ExprType::L1Distance => (call, PbDistanceType::L1),
83 ExprType::L2Distance => (call, PbDistanceType::L2Sqr),
84 ExprType::CosineDistance => (call, PbDistanceType::Cosine),
85 ExprType::Neg => {
86 let [neg_input] = call.inputs() else {
87 return None;
88 };
89 let ExprImpl::FunctionCall(call) = neg_input else {
90 return None;
91 };
92 if let ExprType::InnerProduct = call.func_type() {
93 (call, PbDistanceType::InnerProduct)
94 } else {
95 return None;
96 }
97 }
98 _ => {
99 return None;
100 }
101 };
102 assert_eq!(
103 call.inputs().len(),
104 2,
105 "vector distance function should have exactly two arguments",
106 );
107
108 let [left, right]: &[_; 2] = call.inputs().try_into().unwrap();
109 assert_matches!(left.return_type(), DataType::Vector(_));
110 assert_matches!(right.return_type(), DataType::Vector(_));
111
112 let vector_search = LogicalVectorSearch::new(
113 limit,
114 distance_type,
115 left.clone(),
116 right.clone(),
117 (0..projection_input.schema().len()).collect(),
118 projection_input.clone(),
119 );
120 let mut i2o_mapping = vector_search.i2o_mapping();
121 let mut output_exprs = Vec::with_capacity(exprs.len());
122 for expr in &exprs[0..order.column_index] {
123 output_exprs.push(i2o_mapping.rewrite_expr(expr.clone()));
124 }
125 output_exprs.push(ExprImpl::InputRef(
126 InputRef {
127 index: projection_input.schema().len(),
128 data_type: DataType::Float64,
129 }
130 .into(),
131 ));
132 for expr in &exprs[order.column_index + 1..exprs.len()] {
133 output_exprs.push(i2o_mapping.rewrite_expr(expr.clone()));
134 }
135 Some((vector_search, output_exprs))
136 }
137}
138
139impl Rule<Logical> for TopNToVectorSearchRule {
150 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
151 let top_n = plan.as_logical_top_n()?;
152 let (vector_search, project_exprs) = Self::resolve_vector_search(top_n)?;
153 Some(LogicalProject::create(vector_search.into(), project_exprs))
154 }
155}