risingwave_frontend/optimizer/rule/top_n_to_vector_search_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::assert_matches::assert_matches;
16
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_pb::common::PbDistanceType;
20
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, InputRef};
22use crate::optimizer::LogicalPlanRef;
23use crate::optimizer::plan_node::generic::TopNLimit;
24use crate::optimizer::plan_node::{
25    LogicalPlanRef as PlanRef, LogicalProject, LogicalTopN, LogicalVectorSearch, PlanTreeNodeUnary,
26};
27use crate::optimizer::rule::prelude::*;
28use crate::optimizer::rule::{BoxedRule, ProjectMergeRule, Rule};
29
30pub struct TopNToVectorSearchRule;
31
32impl TopNToVectorSearchRule {
33    pub fn create() -> BoxedRule<Logical> {
34        Box::new(TopNToVectorSearchRule)
35    }
36}
37
38fn merge_consecutive_projections(input: LogicalPlanRef) -> Option<(Vec<ExprImpl>, LogicalPlanRef)> {
39    let projection = input.as_logical_project()?;
40    let mut exprs = projection.exprs().clone();
41    let mut input = projection.input();
42    while let Some(projection) = input.as_logical_project() {
43        exprs = ProjectMergeRule::merge_project_exprs(&exprs, projection.exprs(), false)?;
44        input = projection.input();
45    }
46    Some((exprs, input))
47}
48
49impl TopNToVectorSearchRule {
50    fn resolve_vector_search(top_n: &LogicalTopN) -> Option<(LogicalVectorSearch, Vec<ExprImpl>)> {
51        if !top_n.group_key().is_empty() {
52            // vector search applies for only singleton top n
53            return None;
54        }
55        if top_n.offset() > 0 {
56            return None;
57        }
58        let TopNLimit::Simple(limit) = top_n.limit_attr() else {
59            // vector index applies for only simple top n
60            return None;
61        };
62        // vector index applies for only top n with one order column
63        let [order]: &[ColumnOrder; 1] = top_n
64            .topn_order()
65            .column_orders
66            .as_slice()
67            .try_into()
68            .ok()?;
69        if order.order_type.is_descending() || order.order_type.nulls_are_smallest() {
70            // vector index applies for only ascending order with nulls last
71            return None;
72        }
73
74        // TODO: may merge the projections in a finer way so as not to break potential common sub expr.
75        let (exprs, projection_input) = merge_consecutive_projections(top_n.input())?;
76
77        let order_expr = &exprs[order.column_index];
78        let ExprImpl::FunctionCall(call) = order_expr else {
79            return None;
80        };
81        let (call, distance_type) = match call.func_type() {
82            ExprType::L1Distance => (call, PbDistanceType::L1),
83            ExprType::L2Distance => (call, PbDistanceType::L2Sqr),
84            ExprType::CosineDistance => (call, PbDistanceType::Cosine),
85            ExprType::Neg => {
86                let [neg_input] = call.inputs() else {
87                    return None;
88                };
89                let ExprImpl::FunctionCall(call) = neg_input else {
90                    return None;
91                };
92                if let ExprType::InnerProduct = call.func_type() {
93                    (call, PbDistanceType::InnerProduct)
94                } else {
95                    return None;
96                }
97            }
98            _ => {
99                return None;
100            }
101        };
102        assert_eq!(
103            call.inputs().len(),
104            2,
105            "vector distance function should have exactly two arguments",
106        );
107
108        let [left, right]: &[_; 2] = call.inputs().try_into().unwrap();
109        assert_matches!(left.return_type(), DataType::Vector(_));
110        assert_matches!(right.return_type(), DataType::Vector(_));
111
112        let vector_search = LogicalVectorSearch::new(
113            limit,
114            distance_type,
115            left.clone(),
116            right.clone(),
117            (0..projection_input.schema().len()).collect(),
118            projection_input.clone(),
119        );
120        let mut i2o_mapping = vector_search.i2o_mapping();
121        let mut output_exprs = Vec::with_capacity(exprs.len());
122        for expr in &exprs[0..order.column_index] {
123            output_exprs.push(i2o_mapping.rewrite_expr(expr.clone()));
124        }
125        output_exprs.push(ExprImpl::InputRef(
126            InputRef {
127                index: projection_input.schema().len(),
128                data_type: DataType::Float64,
129            }
130            .into(),
131        ));
132        for expr in &exprs[order.column_index + 1..exprs.len()] {
133            output_exprs.push(i2o_mapping.rewrite_expr(expr.clone()));
134        }
135        Some((vector_search, output_exprs))
136    }
137}
138
139/// This rule converts the following TopN pattern to `LogicalVectorSearch`
140/// ```text
141///     LogicalTopN { order: [$expr1 ASC], limit: TOP_N, offset: 0 }
142///       └─LogicalProject { exprs: [VectorDistanceFunc(vector_expr1, vector_expr2) as $expr1, other_exprs...] }
143/// ```
144/// to
145/// ```text
146///     LogicalProject { exprs: [other_exprs... + distance_column] }
147///       └─LogicalVectorSearch { distance_type: `PbDistanceType`, top_n: TOP_N, left: vector_expr1, right: vector_expr2, output_columns: [...] }
148/// ```
149impl Rule<Logical> for TopNToVectorSearchRule {
150    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
151        let top_n = plan.as_logical_top_n()?;
152        let (vector_search, project_exprs) = Self::resolve_vector_search(top_n)?;
153        Some(LogicalProject::create(vector_search.into(), project_exprs))
154    }
155}