risingwave_frontend/optimizer/rule/
top_n_to_vector_search_rule.rs1use std::assert_matches::assert_matches;
16
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_pb::common::PbDistanceType;
20
21use crate::expr::{Expr, ExprImpl, ExprType, InputRef};
22use crate::optimizer::LogicalPlanRef;
23use crate::optimizer::plan_node::generic::{GenericPlanRef, TopNLimit};
24use crate::optimizer::plan_node::{
25 LogicalPlanRef as PlanRef, LogicalProject, LogicalTopN, LogicalVectorSearch, PlanTreeNodeUnary,
26};
27use crate::optimizer::rule::prelude::*;
28use crate::optimizer::rule::{BoxedRule, ProjectMergeRule, Rule};
29
30pub struct TopNToVectorSearchRule;
31
32impl TopNToVectorSearchRule {
33 pub fn create() -> BoxedRule<Logical> {
34 Box::new(TopNToVectorSearchRule)
35 }
36}
37
38fn merge_consecutive_projections(input: LogicalPlanRef) -> Option<(Vec<ExprImpl>, LogicalPlanRef)> {
39 let projection = input.as_logical_project()?;
40 let mut exprs = projection.exprs().clone();
41 let mut input = projection.input();
42 while let Some(projection) = input.as_logical_project() {
43 exprs = ProjectMergeRule::merge_project_exprs(&exprs, projection.exprs(), false)?;
44 input = projection.input();
45 }
46 Some((exprs, input))
47}
48
49impl TopNToVectorSearchRule {
50 #[expect(clippy::type_complexity)]
51 pub(super) fn resolve_vector_search(
52 top_n: &LogicalTopN,
53 ) -> Option<(
54 (u64, PbDistanceType, ExprImpl, ExprImpl, PlanRef),
55 Vec<ExprImpl>,
56 )> {
57 if !top_n.group_key().is_empty() {
58 return None;
60 }
61 if top_n.offset() > 0 {
62 return None;
63 }
64 let TopNLimit::Simple(limit) = top_n.limit_attr() else {
65 return None;
67 };
68 let [order]: &[ColumnOrder; 1] = top_n
70 .topn_order()
71 .column_orders
72 .as_slice()
73 .try_into()
74 .ok()?;
75 if order.order_type.is_descending() || order.order_type.nulls_are_smallest() {
76 return None;
78 }
79
80 let (exprs, projection_input) = merge_consecutive_projections(top_n.input())?;
82
83 let order_expr = &exprs[order.column_index];
84 let ExprImpl::FunctionCall(call) = order_expr else {
85 return None;
86 };
87 let (call, distance_type) = match call.func_type() {
88 ExprType::L1Distance => (call, PbDistanceType::L1),
89 ExprType::L2Distance => (call, PbDistanceType::L2Sqr),
90 ExprType::CosineDistance => (call, PbDistanceType::Cosine),
91 ExprType::Neg => {
92 let [neg_input] = call.inputs() else {
93 return None;
94 };
95 let ExprImpl::FunctionCall(call) = neg_input else {
96 return None;
97 };
98 if let ExprType::InnerProduct = call.func_type() {
99 (call, PbDistanceType::InnerProduct)
100 } else {
101 return None;
102 }
103 }
104 _ => {
105 return None;
106 }
107 };
108 assert_eq!(
109 call.inputs().len(),
110 2,
111 "vector distance function should have exactly two arguments",
112 );
113
114 let [left, right]: &[_; 2] = call.inputs().try_into().unwrap();
115 assert_matches!(left.return_type(), DataType::Vector(_));
116 assert_matches!(right.return_type(), DataType::Vector(_));
117
118 let mut output_exprs = Vec::with_capacity(exprs.len());
119 for expr in &exprs[0..order.column_index] {
120 output_exprs.push(expr.clone());
121 }
122 output_exprs.push(ExprImpl::InputRef(
123 InputRef {
124 index: projection_input.schema().len(),
125 data_type: DataType::Float64,
126 }
127 .into(),
128 ));
129 for expr in &exprs[order.column_index + 1..exprs.len()] {
130 output_exprs.push(expr.clone());
131 }
132 Some((
133 (
134 limit,
135 distance_type,
136 left.clone(),
137 right.clone(),
138 projection_input,
139 ),
140 output_exprs,
141 ))
142 }
143}
144
145impl Rule<Logical> for TopNToVectorSearchRule {
156 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
157 let top_n = plan.as_logical_top_n()?;
158 let ((top_n, distance_type, left, right, input), project_exprs) =
159 TopNToVectorSearchRule::resolve_vector_search(top_n)?;
160 let vector_search = LogicalVectorSearch::new(top_n, distance_type, left, right, input);
161 Some(LogicalProject::create(vector_search.into(), project_exprs))
162 }
163}