risingwave_frontend/optimizer/rule/
rewrite_like_expr_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::min;
16use std::str::from_utf8;
17
18use risingwave_common::types::{DataType, ScalarImpl};
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, Literal};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter};
24
/// `RewriteLikeExprRule` rewrites simple `LIKE` expressions found in a `LogicalFilter` into
/// comparisons on the literal prefix of the pattern, so that they can benefit from index
/// selection.
///
/// Examples:
///
/// ```text
/// col like 'ABC'   => col = 'ABC'
/// col like 'ABC%'  => col >= 'ABC' and col < 'ABD'
/// col like 'ABC%E' => col >= 'ABC' and col < 'ABD' and col like 'ABC%E'
/// ```
///
/// Only the binary form `<input ref> LIKE <varchar literal>` is considered. The expression is
/// left as it is when the pattern starts with a wildcard, or when the lower/upper bound derived
/// from the prefix is not valid UTF-8.
pub struct RewriteLikeExprRule {}
30impl Rule for RewriteLikeExprRule {
31    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
32        let filter: &LogicalFilter = plan.as_logical_filter()?;
33        if filter.predicate().conjunctions.iter().any(|expr| {
34            let mut has_like = HasSimpleLikeExprVisitor { has: false };
35            has_like.visit_expr(expr);
36            has_like.has
37        }) {
38            let mut rewriter = LikeExprRewriter {};
39            Some(filter.rewrite_exprs(&mut rewriter))
40        } else {
41            None
42        }
43    }
44}
45
/// Expression visitor that records whether a simple `LIKE` call has been seen, i.e. a
/// two-argument `LIKE` whose operands are an input reference and a literal.
struct HasSimpleLikeExprVisitor {
    /// Set to `true` once a simple `LIKE` call is visited; callers initialize it to `false`.
    has: bool,
}
49
50impl ExprVisitor for HasSimpleLikeExprVisitor {
51    fn visit_function_call(&mut self, func_call: &FunctionCall) {
52        // Simple like expression is a binary operation, e.g. col like 'ABC%'
53        // While col like 'ABC%' ESCAPE '!' is a complex like expression.
54        if func_call.func_type() == ExprType::Like
55            && func_call.inputs().len() == 2
56            && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) =
57                func_call.clone().decompose_as_binary()
58        {
59            self.has = true;
60        } else {
61            func_call
62                .inputs()
63                .iter()
64                .for_each(|expr| self.visit_expr(expr));
65        }
66    }
67}
68
/// Expression rewriter that turns simple `LIKE` calls into prefix comparisons.
struct LikeExprRewriter {}

impl LikeExprRewriter {
    /// Strips the `\` escape characters from a `LIKE` pattern and locates its first wildcards.
    ///
    /// Returns `(char_wildcard_idx, str_wildcard_idx, unescaped_bytes)` where
    /// * `char_wildcard_idx` is the position of the first unescaped `_`,
    /// * `str_wildcard_idx` is the position of the first unescaped `%`,
    /// * `unescaped_bytes` is the pattern with the escape characters removed.
    ///
    /// Both positions index into `unescaped_bytes`, not into the input.
    ///
    /// # Panics
    ///
    /// Panics if the pattern ends with a dangling escape character.
    fn cal_index_and_unescape(bytes: &[u8]) -> (Option<usize>, Option<usize>, Vec<u8>) {
        const ESCAPE: u8 = b'\\';

        let mut unescaped: Vec<u8> = vec![];
        let mut first_underscore: Option<usize> = None;
        let mut first_percent: Option<usize> = None;
        // True while the previous byte was an escape character that has not been consumed yet.
        let mut pending_escape = false;

        for &byte in bytes {
            match (pending_escape, byte) {
                // An unescaped backslash only arms the escape; it is not part of the output.
                (false, ESCAPE) => pending_escape = true,
                // Whatever follows an escape is copied verbatim and is never a wildcard.
                (true, _) => {
                    pending_escape = false;
                    unescaped.push(byte);
                }
                (false, _) => {
                    let position = unescaped.len();
                    unescaped.push(byte);
                    match byte {
                        b'_' => {
                            first_underscore.get_or_insert(position);
                        }
                        b'%' => {
                            first_percent.get_or_insert(position);
                        }
                        _ => {}
                    }
                }
            }
        }

        // Note: we only support `\\` as the escape character now, and it can't be positioned at
        // the end of string, which will be banned by parser.
        assert!(!pending_escape);
        (first_underscore, first_percent, unescaped)
    }
}
105
106impl ExprRewriter for LikeExprRewriter {
107    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
108        let (func_type, inputs, ret) = func_call.decompose();
109        let inputs: Vec<ExprImpl> = inputs
110            .into_iter()
111            .map(|expr| self.rewrite_expr(expr))
112            .collect();
113        let func_call = FunctionCall::new_unchecked(func_type, inputs, ret.clone());
114
115        if func_call.func_type() != ExprType::Like {
116            return func_call.into();
117        }
118
119        let (_, ExprImpl::InputRef(x), ExprImpl::Literal(y)) =
120            func_call.clone().decompose_as_binary()
121        else {
122            return func_call.into();
123        };
124
125        if y.return_type() != DataType::Varchar {
126            return func_call.into();
127        }
128
129        let data = y.get_data();
130        let Some(ScalarImpl::Utf8(data)) = data else {
131            return func_call.into();
132        };
133
134        let bytes = data.as_bytes();
135        let len = bytes.len();
136
137        let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
138            Self::cal_index_and_unescape(bytes);
139
140        let idx = match (char_wildcard_idx, str_wildcard_idx) {
141            (Some(a), Some(b)) => min(a, b),
142            (Some(idx), None) => idx,
143            (None, Some(idx)) => idx,
144            (None, None) => {
145                let Ok(unescaped_y) = String::from_utf8(unescaped_bytes) else {
146                    // FIXME: We should definitely treat the argument as UTF-8 string instead of bytes, but currently, we just fallback here.
147                    return func_call.into();
148                };
149                let inputs = vec![
150                    ExprImpl::InputRef(x),
151                    ExprImpl::literal_varchar(unescaped_y),
152                ];
153                let func_call = FunctionCall::new_unchecked(ExprType::Equal, inputs, ret);
154                return func_call.into();
155            }
156        };
157
158        if idx == 0 {
159            return func_call.into();
160        }
161
162        let (low, high) = {
163            let low = unescaped_bytes[0..idx].to_owned();
164            if low[idx - 1] == 255 {
165                return func_call.into();
166            }
167            let mut high = low.clone();
168            high[idx - 1] += 1;
169            match (from_utf8(&low), from_utf8(&high)) {
170                (Ok(low), Ok(high)) => (low.to_owned(), high.to_owned()),
171                _ => {
172                    return func_call.into();
173                }
174            }
175        };
176
177        let between = FunctionCall::new_unchecked(
178            ExprType::And,
179            vec![
180                FunctionCall::new_unchecked(
181                    ExprType::GreaterThanOrEqual,
182                    vec![
183                        ExprImpl::InputRef(x.clone()),
184                        ExprImpl::Literal(
185                            Literal::new(Some(ScalarImpl::Utf8(low.into())), DataType::Varchar)
186                                .into(),
187                        ),
188                    ],
189                    DataType::Boolean,
190                )
191                .into(),
192                FunctionCall::new_unchecked(
193                    ExprType::LessThan,
194                    vec![
195                        ExprImpl::InputRef(x),
196                        ExprImpl::Literal(
197                            Literal::new(Some(ScalarImpl::Utf8(high.into())), DataType::Varchar)
198                                .into(),
199                        ),
200                    ],
201                    DataType::Boolean,
202                )
203                .into(),
204            ],
205            DataType::Boolean,
206        );
207
208        if idx == len - 1 {
209            between.into()
210        } else {
211            FunctionCall::new_unchecked(
212                ExprType::And,
213                vec![between.into(), func_call.into()],
214                DataType::Boolean,
215            )
216            .into()
217        }
218    }
219}
220
221impl RewriteLikeExprRule {
222    pub fn create() -> BoxedRule {
223        Box::new(RewriteLikeExprRule {})
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::LikeExprRewriter;

    /// Checks the wildcard positions and the unescaped pattern produced for a set of
    /// plain, wildcard and escaped patterns.
    #[test]
    fn test_cal_index_and_unescape() {
        // Each case is `(pattern, (first `_` index, first `%` index, unescaped pattern))`.
        #[expect(clippy::type_complexity, reason = "in testcase")]
        let testcases: [(&str, (Option<usize>, Option<usize>, &str)); 7] = [
            ("testname", (None, None, "testname")),
            ("test_name", (Some(4), None, "test_name")),
            ("test_name_2", (Some(4), None, "test_name_2")),
            ("test%name", (None, Some(4), "test%name")),
            (r"test\_name", (None, None, "test_name")),
            (r"test\_name_2", (Some(9), None, "test_name_2")),
            (r"test\\_name_2", (Some(5), None, r"test\_name_2")),
        ];

        for (pattern, (expected_char_idx, expected_str_idx, expected_unescaped)) in testcases {
            let (actual_char_idx, actual_str_idx, actual_unescaped) =
                LikeExprRewriter::cal_index_and_unescape(pattern.as_bytes());

            assert_eq!(actual_char_idx, expected_char_idx);
            assert_eq!(actual_str_idx, expected_str_idx);
            assert_eq!(
                String::from_utf8(actual_unescaped).unwrap(),
                expected_unescaped
            );
        }
    }
}