risingwave_frontend/optimizer/rule/
rewrite_like_expr_rule.rs1use std::cmp::min;
16use std::str::from_utf8;
17
18use risingwave_common::types::{DataType, ScalarImpl};
19
20use super::prelude::{PlanRef, *};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, Literal};
22use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter};
23
24pub struct RewriteLikeExprRule {}
29impl Rule<Logical> for RewriteLikeExprRule {
30 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
31 let filter: &LogicalFilter = plan.as_logical_filter()?;
32 if filter.predicate().conjunctions.iter().any(|expr| {
33 let mut has_like = HasSimpleLikeExprVisitor { has: false };
34 has_like.visit_expr(expr);
35 has_like.has
36 }) {
37 let mut rewriter = LikeExprRewriter {};
38 Some(filter.rewrite_exprs(&mut rewriter))
39 } else {
40 None
41 }
42 }
43}
44
45struct HasSimpleLikeExprVisitor {
46 has: bool,
47}
48
49impl ExprVisitor for HasSimpleLikeExprVisitor {
50 fn visit_function_call(&mut self, func_call: &FunctionCall) {
51 if func_call.func_type() == ExprType::Like
54 && func_call.inputs().len() == 2
55 && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) =
56 func_call.clone().decompose_as_binary()
57 {
58 self.has = true;
59 } else {
60 func_call
61 .inputs()
62 .iter()
63 .for_each(|expr| self.visit_expr(expr));
64 }
65 }
66}
67
/// Stateless expression rewriter for `LIKE` calls; see its `ExprRewriter` implementation.
struct LikeExprRewriter {}

impl LikeExprRewriter {
    /// Scans the raw bytes of a `LIKE` pattern, removing backslash escapes.
    ///
    /// Returns `(first_char_wildcard, first_str_wildcard, unescaped_bytes)` where:
    /// - `first_char_wildcard` is the index of the first unescaped `_`, if any;
    /// - `first_str_wildcard` is the index of the first unescaped `%`, if any;
    /// - `unescaped_bytes` is the pattern with every escaping `\` dropped and the byte that
    ///   followed it kept verbatim.
    ///
    /// Both indices refer to positions in `unescaped_bytes`, not in the input.
    ///
    /// # Panics
    ///
    /// Panics if the pattern ends with a dangling escape character.
    fn cal_index_and_unescape(bytes: &[u8]) -> (Option<usize>, Option<usize>, Vec<u8>) {
        const ESCAPE: u8 = b'\\';

        let mut unescaped_bytes = vec![];
        let mut first_char_wildcard: Option<usize> = None;
        let mut first_str_wildcard: Option<usize> = None;
        let mut in_escape = false;

        for &byte in bytes {
            match (in_escape, byte) {
                // An unescaped backslash starts an escape sequence and is itself dropped.
                (false, ESCAPE) => in_escape = true,
                // The byte right after an escape is always literal, wildcard or not.
                (true, _) => {
                    in_escape = false;
                    unescaped_bytes.push(byte);
                }
                // Plain byte: keep it and remember the first position of each wildcard kind.
                (false, _) => {
                    let position = unescaped_bytes.len();
                    unescaped_bytes.push(byte);
                    match byte {
                        b'_' => {
                            first_char_wildcard.get_or_insert(position);
                        }
                        b'%' => {
                            first_str_wildcard.get_or_insert(position);
                        }
                        _ => {}
                    }
                }
            }
        }

        // NOTE(review): presumably a trailing escape is rejected before this point — confirm.
        assert!(!in_escape);
        (first_char_wildcard, first_str_wildcard, unescaped_bytes)
    }
}
104
105impl ExprRewriter for LikeExprRewriter {
106 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
107 let (func_type, inputs, ret) = func_call.decompose();
108 let inputs: Vec<ExprImpl> = inputs
109 .into_iter()
110 .map(|expr| self.rewrite_expr(expr))
111 .collect();
112 let func_call = FunctionCall::new_unchecked(func_type, inputs, ret.clone());
113
114 if func_call.func_type() != ExprType::Like {
115 return func_call.into();
116 }
117
118 let (_, ExprImpl::InputRef(x), ExprImpl::Literal(y)) =
119 func_call.clone().decompose_as_binary()
120 else {
121 return func_call.into();
122 };
123
124 if y.return_type() != DataType::Varchar {
125 return func_call.into();
126 }
127
128 let data = y.get_data();
129 let Some(ScalarImpl::Utf8(data)) = data else {
130 return func_call.into();
131 };
132
133 let bytes = data.as_bytes();
134 let len = bytes.len();
135
136 let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
137 Self::cal_index_and_unescape(bytes);
138
139 let idx = match (char_wildcard_idx, str_wildcard_idx) {
140 (Some(a), Some(b)) => min(a, b),
141 (Some(idx), None) => idx,
142 (None, Some(idx)) => idx,
143 (None, None) => {
144 let Ok(unescaped_y) = String::from_utf8(unescaped_bytes) else {
145 return func_call.into();
147 };
148 let inputs = vec![
149 ExprImpl::InputRef(x),
150 ExprImpl::literal_varchar(unescaped_y),
151 ];
152 let func_call = FunctionCall::new_unchecked(ExprType::Equal, inputs, ret);
153 return func_call.into();
154 }
155 };
156
157 if idx == 0 {
158 return func_call.into();
159 }
160
161 let (low, high) = {
162 let low = unescaped_bytes[0..idx].to_owned();
163 if low[idx - 1] == 255 {
164 return func_call.into();
165 }
166 let mut high = low.clone();
167 high[idx - 1] += 1;
168 match (from_utf8(&low), from_utf8(&high)) {
169 (Ok(low), Ok(high)) => (low.to_owned(), high.to_owned()),
170 _ => {
171 return func_call.into();
172 }
173 }
174 };
175
176 let between = FunctionCall::new_unchecked(
177 ExprType::And,
178 vec![
179 FunctionCall::new_unchecked(
180 ExprType::GreaterThanOrEqual,
181 vec![
182 ExprImpl::InputRef(x.clone()),
183 ExprImpl::Literal(
184 Literal::new(Some(ScalarImpl::Utf8(low.into())), DataType::Varchar)
185 .into(),
186 ),
187 ],
188 DataType::Boolean,
189 )
190 .into(),
191 FunctionCall::new_unchecked(
192 ExprType::LessThan,
193 vec![
194 ExprImpl::InputRef(x),
195 ExprImpl::Literal(
196 Literal::new(Some(ScalarImpl::Utf8(high.into())), DataType::Varchar)
197 .into(),
198 ),
199 ],
200 DataType::Boolean,
201 )
202 .into(),
203 ],
204 DataType::Boolean,
205 );
206
207 if idx == len - 1 {
208 between.into()
209 } else {
210 FunctionCall::new_unchecked(
211 ExprType::And,
212 vec![between.into(), func_call.into()],
213 DataType::Boolean,
214 )
215 .into()
216 }
217 }
218}
219
220impl RewriteLikeExprRule {
221 pub fn create() -> BoxedRule {
222 Box::new(RewriteLikeExprRule {})
223 }
224}
225
#[cfg(test)]
mod tests {
    /// Checks wildcard positions and unescaped output for a table of patterns.
    #[test]
    fn test_cal_index_and_unescape() {
        // Each case is `(pattern, (first `_` index, first `%` index, unescaped pattern))`.
        #[expect(clippy::type_complexity, reason = "in testcase")]
        let testcases: [(&str, (Option<usize>, Option<usize>, &str)); 7] = [
            ("testname", (None, None, "testname")),
            ("test_name", (Some(4), None, "test_name")),
            ("test_name_2", (Some(4), None, "test_name_2")),
            ("test%name", (None, Some(4), "test%name")),
            (r"test\_name", (None, None, "test_name")),
            (r"test\_name_2", (Some(9), None, "test_name_2")),
            (r"test\\_name_2", (Some(5), None, r"test\_name_2")),
        ];

        for (pattern, (expected_char_idx, expected_str_idx, expected_unescaped)) in testcases {
            let (actual_char_idx, actual_str_idx, actual_bytes) =
                super::LikeExprRewriter::cal_index_and_unescape(pattern.as_bytes());
            let actual_unescaped = String::from_utf8(actual_bytes).unwrap();

            assert_eq!(actual_char_idx, expected_char_idx);
            assert_eq!(actual_str_idx, expected_str_idx);
            assert_eq!(actual_unescaped, expected_unescaped);
        }
    }
}