risingwave_frontend/optimizer/rule/
rewrite_like_expr_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::min;
16use std::str::from_utf8;
17
18use risingwave_common::types::{DataType, ScalarImpl};
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, Literal};
22use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter};
23
/// `RewriteLikeExprRule` rewrites simple `LIKE` predicates of a filter into plain comparisons,
/// so that they can benefit from index selection:
///
/// - `col like 'ABC'` => `col = 'ABC'`
/// - `col like 'ABC%'` => `col >= 'ABC' and col < 'ABD'`
/// - `col like 'ABC%E'` => `col >= 'ABC' and col < 'ABD' and col like 'ABC%E'`
pub struct RewriteLikeExprRule {}
29impl Rule<Logical> for RewriteLikeExprRule {
30    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
31        let filter: &LogicalFilter = plan.as_logical_filter()?;
32        if filter.predicate().conjunctions.iter().any(|expr| {
33            let mut has_like = HasSimpleLikeExprVisitor { has: false };
34            has_like.visit_expr(expr);
35            has_like.has
36        }) {
37            let mut rewriter = LikeExprRewriter {};
38            Some(filter.rewrite_exprs(&mut rewriter))
39        } else {
40            None
41        }
42    }
43}
44
/// Expression visitor that records whether a simple `LIKE` expression was encountered.
struct HasSimpleLikeExprVisitor {
    /// Set to `true` once a simple `LIKE` expression has been visited.
    has: bool,
}
48
49impl ExprVisitor for HasSimpleLikeExprVisitor {
50    fn visit_function_call(&mut self, func_call: &FunctionCall) {
51        // Simple like expression is a binary operation, e.g. col like 'ABC%'
52        // While col like 'ABC%' ESCAPE '!' is a complex like expression.
53        if func_call.func_type() == ExprType::Like
54            && func_call.inputs().len() == 2
55            && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) =
56                func_call.clone().decompose_as_binary()
57        {
58            self.has = true;
59        } else {
60            func_call
61                .inputs()
62                .iter()
63                .for_each(|expr| self.visit_expr(expr));
64        }
65    }
66}
67
/// Expression rewriter for simple `LIKE` expressions.
struct LikeExprRewriter {}

impl LikeExprRewriter {
    /// Strips escape characters from a `LIKE` pattern and locates its first wildcards.
    ///
    /// Returns `(first_char_wildcard, first_str_wildcard, unescaped)` where:
    /// - `first_char_wildcard` is the index, within `unescaped`, of the first unescaped `_`;
    /// - `first_str_wildcard` is the index, within `unescaped`, of the first unescaped `%`;
    /// - `unescaped` is the pattern with every escaping `\` removed.
    ///
    /// # Panics
    ///
    /// Panics if the pattern ends with a dangling escape character.
    fn cal_index_and_unescape(bytes: &[u8]) -> (Option<usize>, Option<usize>, Vec<u8>) {
        const ESCAPE: u8 = b'\\';

        let mut unescaped: Vec<u8> = Vec::with_capacity(bytes.len());
        let mut first_char_wildcard: Option<usize> = None;
        let mut first_str_wildcard: Option<usize> = None;
        // True while the previous byte was an escape character that has not been consumed yet.
        let mut escape_pending = false;

        for &byte in bytes {
            if escape_pending {
                // An escaped byte is always literal, even if it is `_`, `%` or `\`.
                escape_pending = false;
                unescaped.push(byte);
                continue;
            }
            match byte {
                ESCAPE => escape_pending = true,
                b'_' => {
                    if first_char_wildcard.is_none() {
                        first_char_wildcard = Some(unescaped.len());
                    }
                    unescaped.push(byte);
                }
                b'%' => {
                    if first_str_wildcard.is_none() {
                        first_str_wildcard = Some(unescaped.len());
                    }
                    unescaped.push(byte);
                }
                _ => unescaped.push(byte),
            }
        }

        // Note: we only support `\\` as the escape character now, and it can't be positioned at
        // the end of string, which will be banned by parser.
        assert!(!escape_pending);
        (first_char_wildcard, first_str_wildcard, unescaped)
    }
}
104
105impl ExprRewriter for LikeExprRewriter {
106    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
107        let (func_type, inputs, ret) = func_call.decompose();
108        let inputs: Vec<ExprImpl> = inputs
109            .into_iter()
110            .map(|expr| self.rewrite_expr(expr))
111            .collect();
112        let func_call = FunctionCall::new_unchecked(func_type, inputs, ret.clone());
113
114        if func_call.func_type() != ExprType::Like {
115            return func_call.into();
116        }
117
118        let (_, ExprImpl::InputRef(x), ExprImpl::Literal(y)) =
119            func_call.clone().decompose_as_binary()
120        else {
121            return func_call.into();
122        };
123
124        if y.return_type() != DataType::Varchar {
125            return func_call.into();
126        }
127
128        let data = y.get_data();
129        let Some(ScalarImpl::Utf8(data)) = data else {
130            return func_call.into();
131        };
132
133        let bytes = data.as_bytes();
134        let len = bytes.len();
135
136        let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
137            Self::cal_index_and_unescape(bytes);
138
139        let idx = match (char_wildcard_idx, str_wildcard_idx) {
140            (Some(a), Some(b)) => min(a, b),
141            (Some(idx), None) => idx,
142            (None, Some(idx)) => idx,
143            (None, None) => {
144                let Ok(unescaped_y) = String::from_utf8(unescaped_bytes) else {
145                    // FIXME: We should definitely treat the argument as UTF-8 string instead of bytes, but currently, we just fallback here.
146                    return func_call.into();
147                };
148                let inputs = vec![
149                    ExprImpl::InputRef(x),
150                    ExprImpl::literal_varchar(unescaped_y),
151                ];
152                let func_call = FunctionCall::new_unchecked(ExprType::Equal, inputs, ret);
153                return func_call.into();
154            }
155        };
156
157        if idx == 0 {
158            return func_call.into();
159        }
160
161        let (low, high) = {
162            let low = unescaped_bytes[0..idx].to_owned();
163            if low[idx - 1] == 255 {
164                return func_call.into();
165            }
166            let mut high = low.clone();
167            high[idx - 1] += 1;
168            match (from_utf8(&low), from_utf8(&high)) {
169                (Ok(low), Ok(high)) => (low.to_owned(), high.to_owned()),
170                _ => {
171                    return func_call.into();
172                }
173            }
174        };
175
176        let between = FunctionCall::new_unchecked(
177            ExprType::And,
178            vec![
179                FunctionCall::new_unchecked(
180                    ExprType::GreaterThanOrEqual,
181                    vec![
182                        ExprImpl::InputRef(x.clone()),
183                        ExprImpl::Literal(
184                            Literal::new(Some(ScalarImpl::Utf8(low.into())), DataType::Varchar)
185                                .into(),
186                        ),
187                    ],
188                    DataType::Boolean,
189                )
190                .into(),
191                FunctionCall::new_unchecked(
192                    ExprType::LessThan,
193                    vec![
194                        ExprImpl::InputRef(x),
195                        ExprImpl::Literal(
196                            Literal::new(Some(ScalarImpl::Utf8(high.into())), DataType::Varchar)
197                                .into(),
198                        ),
199                    ],
200                    DataType::Boolean,
201                )
202                .into(),
203            ],
204            DataType::Boolean,
205        );
206
207        if idx == len - 1 {
208            between.into()
209        } else {
210            FunctionCall::new_unchecked(
211                ExprType::And,
212                vec![between.into(), func_call.into()],
213                DataType::Boolean,
214            )
215            .into()
216        }
217    }
218}
219
220impl RewriteLikeExprRule {
221    pub fn create() -> BoxedRule {
222        Box::new(RewriteLikeExprRule {})
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::LikeExprRewriter;

    /// Checks the wildcard indexes and the unescaped pattern computed for plain, wildcard and
    /// escaped patterns.
    #[test]
    fn test_cal_index_and_unescape() {
        // Each entry is `(pattern, (first '_' index, first '%' index, unescaped pattern))`.
        #[expect(clippy::type_complexity, reason = "in testcase")]
        let testcases: [(&str, (Option<usize>, Option<usize>, &str)); 7] = [
            ("testname", (None, None, "testname")),
            ("test_name", (Some(4), None, "test_name")),
            ("test_name_2", (Some(4), None, "test_name_2")),
            ("test%name", (None, Some(4), "test%name")),
            (r"test\_name", (None, None, "test_name")),
            (r"test\_name_2", (Some(9), None, "test_name_2")),
            (r"test\\_name_2", (Some(5), None, r"test\_name_2")),
        ];

        for (pattern, (expected_char_idx, expected_str_idx, expected_unescaped)) in testcases {
            let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
                LikeExprRewriter::cal_index_and_unescape(pattern.as_bytes());
            let unescaped = String::from_utf8(unescaped_bytes).unwrap();

            assert_eq!(char_wildcard_idx, expected_char_idx);
            assert_eq!(str_wildcard_idx, expected_str_idx);
            assert_eq!(unescaped, expected_unescaped);
        }
    }
}