risingwave_frontend/optimizer/rule/
rewrite_like_expr_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::min;
16use std::str::from_utf8;
17
18use risingwave_common::types::{DataType, ScalarImpl};
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, Literal};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter};
24
25/// `RewriteLikeExprRule` rewrites like expression, so that it can benefit from index selection.
26/// col like 'ABC' => col = 'ABC'
27/// col like 'ABC%' => col >= 'ABC' and col < 'ABD'
28/// col like 'ABC%E' => col >= 'ABC' and col < 'ABD' and col like 'ABC%E'
29pub struct RewriteLikeExprRule {}
30impl Rule for RewriteLikeExprRule {
31    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
32        let filter: &LogicalFilter = plan.as_logical_filter()?;
33        if filter.predicate().conjunctions.iter().any(|expr| {
34            let mut has_like = HasLikeExprVisitor { has: false };
35            has_like.visit_expr(expr);
36            has_like.has
37        }) {
38            let mut rewriter = LikeExprRewriter {};
39            Some(filter.rewrite_exprs(&mut rewriter))
40        } else {
41            None
42        }
43    }
44}
45
46struct HasLikeExprVisitor {
47    has: bool,
48}
49
50impl ExprVisitor for HasLikeExprVisitor {
51    fn visit_function_call(&mut self, func_call: &FunctionCall) {
52        if func_call.func_type() == ExprType::Like
53            && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) =
54                func_call.clone().decompose_as_binary()
55        {
56            self.has = true;
57        } else {
58            func_call
59                .inputs()
60                .iter()
61                .for_each(|expr| self.visit_expr(expr));
62        }
63    }
64}
65
/// Stateless expression rewriter used to rewrite `LIKE` calls.
struct LikeExprRewriter {}

impl LikeExprRewriter {
    /// Removes backslash escapes from a `LIKE` pattern and locates its first wildcards.
    ///
    /// Returns `(first_underscore, first_percent, unescaped)`:
    /// - `first_underscore`: index in `unescaped` of the first unescaped `_`, if any;
    /// - `first_percent`: index in `unescaped` of the first unescaped `%`, if any;
    /// - `unescaped`: the pattern bytes with every escape character dropped and the byte
    ///   following it kept verbatim.
    ///
    /// # Panics
    ///
    /// Panics if the pattern ends with an unpaired escape character. Only `\` is supported as
    /// the escape character, and a trailing one is expected to be rejected by the parser.
    fn cal_index_and_unescape(bytes: &[u8]) -> (Option<usize>, Option<usize>, Vec<u8>) {
        const ESCAPE: u8 = b'\\';

        let mut first_underscore: Option<usize> = None;
        let mut first_percent: Option<usize> = None;
        let mut unescaped: Vec<u8> = Vec::new();
        let mut pending_escape = false;

        for &byte in bytes {
            if pending_escape {
                // An escaped byte is always a literal, even when it is `_`, `%` or `\`.
                pending_escape = false;
                unescaped.push(byte);
                continue;
            }

            match byte {
                ESCAPE => pending_escape = true,
                b'_' => {
                    // The wildcard lands at the current end of the unescaped buffer.
                    if first_underscore.is_none() {
                        first_underscore = Some(unescaped.len());
                    }
                    unescaped.push(byte);
                }
                b'%' => {
                    if first_percent.is_none() {
                        first_percent = Some(unescaped.len());
                    }
                    unescaped.push(byte);
                }
                _ => unescaped.push(byte),
            }
        }

        // A dangling escape at the end of the pattern is not supported.
        assert!(!pending_escape);
        (first_underscore, first_percent, unescaped)
    }
}
102
103impl ExprRewriter for LikeExprRewriter {
104    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
105        let (func_type, inputs, ret) = func_call.decompose();
106        let inputs: Vec<ExprImpl> = inputs
107            .into_iter()
108            .map(|expr| self.rewrite_expr(expr))
109            .collect();
110        let func_call = FunctionCall::new_unchecked(func_type, inputs, ret.clone());
111
112        if func_call.func_type() != ExprType::Like {
113            return func_call.into();
114        }
115
116        let (_, ExprImpl::InputRef(x), ExprImpl::Literal(y)) =
117            func_call.clone().decompose_as_binary()
118        else {
119            return func_call.into();
120        };
121
122        if y.return_type() != DataType::Varchar {
123            return func_call.into();
124        }
125
126        let data = y.get_data();
127        let Some(ScalarImpl::Utf8(data)) = data else {
128            return func_call.into();
129        };
130
131        let bytes = data.as_bytes();
132        let len = bytes.len();
133
134        let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
135            Self::cal_index_and_unescape(bytes);
136
137        let idx = match (char_wildcard_idx, str_wildcard_idx) {
138            (Some(a), Some(b)) => min(a, b),
139            (Some(idx), None) => idx,
140            (None, Some(idx)) => idx,
141            (None, None) => {
142                let Ok(unescaped_y) = String::from_utf8(unescaped_bytes) else {
143                    // FIXME: We should definitely treat the argument as UTF-8 string instead of bytes, but currently, we just fallback here.
144                    return func_call.into();
145                };
146                let inputs = vec![
147                    ExprImpl::InputRef(x),
148                    ExprImpl::literal_varchar(unescaped_y),
149                ];
150                let func_call = FunctionCall::new_unchecked(ExprType::Equal, inputs, ret);
151                return func_call.into();
152            }
153        };
154
155        if idx == 0 {
156            return func_call.into();
157        }
158
159        let (low, high) = {
160            let low = unescaped_bytes[0..idx].to_owned();
161            if low[idx - 1] == 255 {
162                return func_call.into();
163            }
164            let mut high = low.clone();
165            high[idx - 1] += 1;
166            match (from_utf8(&low), from_utf8(&high)) {
167                (Ok(low), Ok(high)) => (low.to_owned(), high.to_owned()),
168                _ => {
169                    return func_call.into();
170                }
171            }
172        };
173
174        let between = FunctionCall::new_unchecked(
175            ExprType::And,
176            vec![
177                FunctionCall::new_unchecked(
178                    ExprType::GreaterThanOrEqual,
179                    vec![
180                        ExprImpl::InputRef(x.clone()),
181                        ExprImpl::Literal(
182                            Literal::new(Some(ScalarImpl::Utf8(low.into())), DataType::Varchar)
183                                .into(),
184                        ),
185                    ],
186                    DataType::Boolean,
187                )
188                .into(),
189                FunctionCall::new_unchecked(
190                    ExprType::LessThan,
191                    vec![
192                        ExprImpl::InputRef(x),
193                        ExprImpl::Literal(
194                            Literal::new(Some(ScalarImpl::Utf8(high.into())), DataType::Varchar)
195                                .into(),
196                        ),
197                    ],
198                    DataType::Boolean,
199                )
200                .into(),
201            ],
202            DataType::Boolean,
203        );
204
205        if idx == len - 1 {
206            between.into()
207        } else {
208            FunctionCall::new_unchecked(
209                ExprType::And,
210                vec![between.into(), func_call.into()],
211                DataType::Boolean,
212            )
213            .into()
214        }
215    }
216}
217
218impl RewriteLikeExprRule {
219    pub fn create() -> BoxedRule {
220        Box::new(RewriteLikeExprRule {})
221    }
222}
223
#[cfg(test)]
mod tests {
    /// Checks the wildcard indices and the unescaped pattern produced by
    /// `LikeExprRewriter::cal_index_and_unescape` for plain, wildcard and escaped patterns.
    #[test]
    fn test_cal_index_and_unescape() {
        // Each case is `(pattern, (first `_` index, first `%` index, unescaped pattern))`.
        #[expect(clippy::type_complexity, reason = "in testcase")]
        let testcases: [(&str, (Option<usize>, Option<usize>, &str)); 7] = [
            ("testname", (None, None, "testname")),
            ("test_name", (Some(4), None, "test_name")),
            ("test_name_2", (Some(4), None, "test_name_2")),
            ("test%name", (None, Some(4), "test%name")),
            (r"test\_name", (None, None, "test_name")),
            (r"test\_name_2", (Some(9), None, "test_name_2")),
            (r"test\\_name_2", (Some(5), None, r"test\_name_2")),
        ];

        for (pattern, (want_char_idx, want_str_idx, want_unescaped)) in testcases {
            let (got_char_idx, got_str_idx, got_bytes) =
                super::LikeExprRewriter::cal_index_and_unescape(pattern.as_bytes());
            let got_unescaped = String::from_utf8(got_bytes).unwrap();

            assert_eq!(got_char_idx, want_char_idx);
            assert_eq!(got_str_idx, want_str_idx);
            assert_eq!(&got_unescaped, want_unescaped);
        }
    }
}