risingwave_frontend/optimizer/rule/
top_n_to_vector_search_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::assert_matches::assert_matches;
16
17use risingwave_common::types::DataType;
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_pb::common::PbDistanceType;
20
21use crate::expr::{Expr, ExprImpl, ExprType, InputRef};
22use crate::optimizer::LogicalPlanRef;
23use crate::optimizer::plan_node::generic::{GenericPlanRef, TopNLimit};
24use crate::optimizer::plan_node::{
25    LogicalPlanRef as PlanRef, LogicalProject, LogicalTopN, LogicalVectorSearch, PlanTreeNodeUnary,
26};
27use crate::optimizer::rule::prelude::*;
28use crate::optimizer::rule::{BoxedRule, ProjectMergeRule, Rule};
29
30pub struct TopNToVectorSearchRule;
31
32impl TopNToVectorSearchRule {
33    pub fn create() -> BoxedRule<Logical> {
34        Box::new(TopNToVectorSearchRule)
35    }
36}
37
38fn merge_consecutive_projections(input: LogicalPlanRef) -> Option<(Vec<ExprImpl>, LogicalPlanRef)> {
39    let projection = input.as_logical_project()?;
40    let mut exprs = projection.exprs().clone();
41    let mut input = projection.input();
42    while let Some(projection) = input.as_logical_project() {
43        exprs = ProjectMergeRule::merge_project_exprs(&exprs, projection.exprs(), false)?;
44        input = projection.input();
45    }
46    Some((exprs, input))
47}
48
49impl TopNToVectorSearchRule {
50    #[expect(clippy::type_complexity)]
51    pub(super) fn resolve_vector_search(
52        top_n: &LogicalTopN,
53    ) -> Option<(
54        (u64, PbDistanceType, ExprImpl, ExprImpl, PlanRef),
55        Vec<ExprImpl>,
56    )> {
57        if !top_n.group_key().is_empty() {
58            // vector search applies for only singleton top n
59            return None;
60        }
61        if top_n.offset() > 0 {
62            return None;
63        }
64        let TopNLimit::Simple(limit) = top_n.limit_attr() else {
65            // vector index applies for only simple top n
66            return None;
67        };
68        // vector index applies for only top n with one order column
69        let [order]: &[ColumnOrder; 1] = top_n
70            .topn_order()
71            .column_orders
72            .as_slice()
73            .try_into()
74            .ok()?;
75        if order.order_type.is_descending() || order.order_type.nulls_are_smallest() {
76            // vector index applies for only ascending order with nulls last
77            return None;
78        }
79
80        // TODO: may merge the projections in a finer way so as not to break potential common sub expr.
81        let (exprs, projection_input) = merge_consecutive_projections(top_n.input())?;
82
83        let order_expr = &exprs[order.column_index];
84        let ExprImpl::FunctionCall(call) = order_expr else {
85            return None;
86        };
87        let (call, distance_type) = match call.func_type() {
88            ExprType::L1Distance => (call, PbDistanceType::L1),
89            ExprType::L2Distance => (call, PbDistanceType::L2Sqr),
90            ExprType::CosineDistance => (call, PbDistanceType::Cosine),
91            ExprType::Neg => {
92                let [neg_input] = call.inputs() else {
93                    return None;
94                };
95                let ExprImpl::FunctionCall(call) = neg_input else {
96                    return None;
97                };
98                if let ExprType::InnerProduct = call.func_type() {
99                    (call, PbDistanceType::InnerProduct)
100                } else {
101                    return None;
102                }
103            }
104            _ => {
105                return None;
106            }
107        };
108        assert_eq!(
109            call.inputs().len(),
110            2,
111            "vector distance function should have exactly two arguments",
112        );
113
114        let [left, right]: &[_; 2] = call.inputs().try_into().unwrap();
115        assert_matches!(left.return_type(), DataType::Vector(_));
116        assert_matches!(right.return_type(), DataType::Vector(_));
117
118        let mut output_exprs = Vec::with_capacity(exprs.len());
119        for expr in &exprs[0..order.column_index] {
120            output_exprs.push(expr.clone());
121        }
122        output_exprs.push(ExprImpl::InputRef(
123            InputRef {
124                index: projection_input.schema().len(),
125                data_type: DataType::Float64,
126            }
127            .into(),
128        ));
129        for expr in &exprs[order.column_index + 1..exprs.len()] {
130            output_exprs.push(expr.clone());
131        }
132        Some((
133            (
134                limit,
135                distance_type,
136                left.clone(),
137                right.clone(),
138                projection_input,
139            ),
140            output_exprs,
141        ))
142    }
143}
144
145/// This rule converts the following TopN pattern to `LogicalVectorSearch`
146/// ```text
147///     LogicalTopN { order: [$expr1 ASC], limit: TOP_N, offset: 0 }
148///       └─LogicalProject { exprs: [VectorDistanceFunc(vector_expr1, vector_expr2) as $expr1, other_exprs...] }
149/// ```
150/// to
151/// ```text
152///     LogicalProject { exprs: [other_exprs... + distance_column] }
153///       └─LogicalVectorSearch { distance_type: `PbDistanceType`, top_n: TOP_N, left: vector_expr1, right: vector_expr2, output_columns: [...] }
154/// ```
155impl Rule<Logical> for TopNToVectorSearchRule {
156    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
157        let top_n = plan.as_logical_top_n()?;
158        let ((top_n, distance_type, left, right, input), project_exprs) =
159            TopNToVectorSearchRule::resolve_vector_search(top_n)?;
160        let vector_search = LogicalVectorSearch::new(top_n, distance_type, left, right, input);
161        Some(LogicalProject::create(vector_search.into(), project_exprs))
162    }
163}