risingwave_frontend/optimizer/rule/
rewrite_like_expr_rule.rs1use std::cmp::min;
16use std::str::from_utf8;
17
18use risingwave_common::types::{DataType, ScalarImpl};
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, Literal};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter};
24
25pub struct RewriteLikeExprRule {}
30impl Rule for RewriteLikeExprRule {
31 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
32 let filter: &LogicalFilter = plan.as_logical_filter()?;
33 if filter.predicate().conjunctions.iter().any(|expr| {
34 let mut has_like = HasLikeExprVisitor { has: false };
35 has_like.visit_expr(expr);
36 has_like.has
37 }) {
38 let mut rewriter = LikeExprRewriter {};
39 Some(filter.rewrite_exprs(&mut rewriter))
40 } else {
41 None
42 }
43 }
44}
45
46struct HasLikeExprVisitor {
47 has: bool,
48}
49
50impl ExprVisitor for HasLikeExprVisitor {
51 fn visit_function_call(&mut self, func_call: &FunctionCall) {
52 if func_call.func_type() == ExprType::Like
53 && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) =
54 func_call.clone().decompose_as_binary()
55 {
56 self.has = true;
57 } else {
58 func_call
59 .inputs()
60 .iter()
61 .for_each(|expr| self.visit_expr(expr));
62 }
63 }
64}
65
/// Expression rewriter that turns `LIKE` calls into cheaper comparisons.
struct LikeExprRewriter {}

impl LikeExprRewriter {
    /// Strips backslash escapes from a `LIKE` pattern and locates its first
    /// unescaped wildcards.
    ///
    /// Returns `(char_wildcard_idx, str_wildcard_idx, unescaped_bytes)`:
    /// * `char_wildcard_idx` - position of the first unescaped `_`, if any;
    /// * `str_wildcard_idx` - position of the first unescaped `%`, if any;
    /// * `unescaped_bytes` - the pattern with every escaping backslash removed.
    ///
    /// Both positions index into `unescaped_bytes`, not into the input. A byte
    /// that follows an escaping backslash is copied verbatim and never counts
    /// as a wildcard.
    ///
    /// # Panics
    ///
    /// Panics if the pattern ends in the middle of an escape sequence, i.e. its
    /// last backslash is not itself escaped.
    fn cal_index_and_unescape(bytes: &[u8]) -> (Option<usize>, Option<usize>, Vec<u8>) {
        const ESCAPE: u8 = b'\\';

        let mut literal: Vec<u8> = Vec::new();
        let mut first_underscore: Option<usize> = None;
        let mut first_percent: Option<usize> = None;
        let mut escaped = false;

        for &byte in bytes {
            match (escaped, byte) {
                // The previous byte was an escape: take this one literally.
                (true, _) => {
                    escaped = false;
                    literal.push(byte);
                }
                // An unescaped backslash starts an escape and is dropped.
                (false, ESCAPE) => escaped = true,
                // Ordinary byte: keep it and remember the first wildcards.
                (false, _) => {
                    let pos = literal.len();
                    literal.push(byte);
                    match byte {
                        b'_' if first_underscore.is_none() => first_underscore = Some(pos),
                        b'%' if first_percent.is_none() => first_percent = Some(pos),
                        _ => {}
                    }
                }
            }
        }

        assert!(!escaped);
        (first_underscore, first_percent, literal)
    }
}
102
103impl ExprRewriter for LikeExprRewriter {
104 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
105 let (func_type, inputs, ret) = func_call.decompose();
106 let inputs: Vec<ExprImpl> = inputs
107 .into_iter()
108 .map(|expr| self.rewrite_expr(expr))
109 .collect();
110 let func_call = FunctionCall::new_unchecked(func_type, inputs, ret.clone());
111
112 if func_call.func_type() != ExprType::Like {
113 return func_call.into();
114 }
115
116 let (_, ExprImpl::InputRef(x), ExprImpl::Literal(y)) =
117 func_call.clone().decompose_as_binary()
118 else {
119 return func_call.into();
120 };
121
122 if y.return_type() != DataType::Varchar {
123 return func_call.into();
124 }
125
126 let data = y.get_data();
127 let Some(ScalarImpl::Utf8(data)) = data else {
128 return func_call.into();
129 };
130
131 let bytes = data.as_bytes();
132 let len = bytes.len();
133
134 let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
135 Self::cal_index_and_unescape(bytes);
136
137 let idx = match (char_wildcard_idx, str_wildcard_idx) {
138 (Some(a), Some(b)) => min(a, b),
139 (Some(idx), None) => idx,
140 (None, Some(idx)) => idx,
141 (None, None) => {
142 let Ok(unescaped_y) = String::from_utf8(unescaped_bytes) else {
143 return func_call.into();
145 };
146 let inputs = vec![
147 ExprImpl::InputRef(x),
148 ExprImpl::literal_varchar(unescaped_y),
149 ];
150 let func_call = FunctionCall::new_unchecked(ExprType::Equal, inputs, ret);
151 return func_call.into();
152 }
153 };
154
155 if idx == 0 {
156 return func_call.into();
157 }
158
159 let (low, high) = {
160 let low = unescaped_bytes[0..idx].to_owned();
161 if low[idx - 1] == 255 {
162 return func_call.into();
163 }
164 let mut high = low.clone();
165 high[idx - 1] += 1;
166 match (from_utf8(&low), from_utf8(&high)) {
167 (Ok(low), Ok(high)) => (low.to_owned(), high.to_owned()),
168 _ => {
169 return func_call.into();
170 }
171 }
172 };
173
174 let between = FunctionCall::new_unchecked(
175 ExprType::And,
176 vec![
177 FunctionCall::new_unchecked(
178 ExprType::GreaterThanOrEqual,
179 vec![
180 ExprImpl::InputRef(x.clone()),
181 ExprImpl::Literal(
182 Literal::new(Some(ScalarImpl::Utf8(low.into())), DataType::Varchar)
183 .into(),
184 ),
185 ],
186 DataType::Boolean,
187 )
188 .into(),
189 FunctionCall::new_unchecked(
190 ExprType::LessThan,
191 vec![
192 ExprImpl::InputRef(x),
193 ExprImpl::Literal(
194 Literal::new(Some(ScalarImpl::Utf8(high.into())), DataType::Varchar)
195 .into(),
196 ),
197 ],
198 DataType::Boolean,
199 )
200 .into(),
201 ],
202 DataType::Boolean,
203 );
204
205 if idx == len - 1 {
206 between.into()
207 } else {
208 FunctionCall::new_unchecked(
209 ExprType::And,
210 vec![between.into(), func_call.into()],
211 DataType::Boolean,
212 )
213 .into()
214 }
215 }
216}
217
218impl RewriteLikeExprRule {
219 pub fn create() -> BoxedRule {
220 Box::new(RewriteLikeExprRule {})
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::LikeExprRewriter;

    /// Checks the wildcard positions and the unescaped pattern produced for
    /// plain, wildcard and backslash-escaped inputs.
    #[test]
    fn test_cal_index_and_unescape() {
        // Each case is `(pattern, (first `_` index, first `%` index, unescaped pattern))`.
        #[expect(clippy::type_complexity, reason = "in testcase")]
        let testcases: [(&str, (Option<usize>, Option<usize>, &str)); 7] = [
            ("testname", (None, None, "testname")),
            ("test_name", (Some(4), None, "test_name")),
            ("test_name_2", (Some(4), None, "test_name_2")),
            ("test%name", (None, Some(4), "test%name")),
            (r"test\_name", (None, None, "test_name")),
            (r"test\_name_2", (Some(9), None, "test_name_2")),
            (r"test\\_name_2", (Some(5), None, r"test\_name_2")),
        ];

        for (pattern, (expected_char_idx, expected_str_idx, expected_unescaped)) in testcases {
            let (char_idx, str_idx, unescaped) =
                LikeExprRewriter::cal_index_and_unescape(pattern.as_bytes());
            assert_eq!(char_idx, expected_char_idx);
            assert_eq!(str_idx, expected_str_idx);

            let unescaped = String::from_utf8(unescaped).unwrap();
            assert_eq!(&unescaped, expected_unescaped);
        }
    }
}