risingwave_frontend/optimizer/rule/
rewrite_like_expr_rule.rs1use std::cmp::min;
16use std::str::from_utf8;
17
18use risingwave_common::types::{DataType, ScalarImpl};
19
20use super::{BoxedRule, Rule};
21use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, Literal};
22use crate::optimizer::PlanRef;
23use crate::optimizer::plan_node::{ExprRewritable, LogicalFilter};
24
25pub struct RewriteLikeExprRule {}
30impl Rule for RewriteLikeExprRule {
31 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
32 let filter: &LogicalFilter = plan.as_logical_filter()?;
33 if filter.predicate().conjunctions.iter().any(|expr| {
34 let mut has_like = HasSimpleLikeExprVisitor { has: false };
35 has_like.visit_expr(expr);
36 has_like.has
37 }) {
38 let mut rewriter = LikeExprRewriter {};
39 Some(filter.rewrite_exprs(&mut rewriter))
40 } else {
41 None
42 }
43 }
44}
45
46struct HasSimpleLikeExprVisitor {
47 has: bool,
48}
49
50impl ExprVisitor for HasSimpleLikeExprVisitor {
51 fn visit_function_call(&mut self, func_call: &FunctionCall) {
52 if func_call.func_type() == ExprType::Like
55 && func_call.inputs().len() == 2
56 && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) =
57 func_call.clone().decompose_as_binary()
58 {
59 self.has = true;
60 } else {
61 func_call
62 .inputs()
63 .iter()
64 .for_each(|expr| self.visit_expr(expr));
65 }
66 }
67}
68
/// Expression rewriter that turns `InputRef LIKE 'literal'` calls into equality or
/// range comparisons where the pattern allows it.
struct LikeExprRewriter {}

impl LikeExprRewriter {
    /// Removes backslash escapes from a `LIKE` pattern and finds its first wildcards.
    ///
    /// Returns `(first '_' index, first '%' index, unescaped bytes)`. Both indices point
    /// into the returned unescaped buffer, not into `bytes`. An escaped `_` or `%` is
    /// copied literally and is not reported as a wildcard. `\\` yields a single `\`.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` ends with an unpaired escape character (`\`).
    fn cal_index_and_unescape(bytes: &[u8]) -> (Option<usize>, Option<usize>, Vec<u8>) {
        const ESCAPE: u8 = b'\\';

        let mut unescaped = Vec::with_capacity(bytes.len());
        let mut first_underscore: Option<usize> = None;
        let mut first_percent: Option<usize> = None;
        // True right after an escape character, i.e. the next byte is taken literally.
        let mut in_escape = false;

        for &byte in bytes {
            match (in_escape, byte) {
                (false, ESCAPE) => in_escape = true,
                (true, _) => {
                    in_escape = false;
                    unescaped.push(byte);
                }
                (false, _) => {
                    let position = unescaped.len();
                    unescaped.push(byte);
                    match byte {
                        b'_' => {
                            first_underscore.get_or_insert(position);
                        }
                        b'%' => {
                            first_percent.get_or_insert(position);
                        }
                        _ => {}
                    }
                }
            }
        }
        assert!(!in_escape);

        (first_underscore, first_percent, unescaped)
    }
}
105
106impl ExprRewriter for LikeExprRewriter {
107 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
108 let (func_type, inputs, ret) = func_call.decompose();
109 let inputs: Vec<ExprImpl> = inputs
110 .into_iter()
111 .map(|expr| self.rewrite_expr(expr))
112 .collect();
113 let func_call = FunctionCall::new_unchecked(func_type, inputs, ret.clone());
114
115 if func_call.func_type() != ExprType::Like {
116 return func_call.into();
117 }
118
119 let (_, ExprImpl::InputRef(x), ExprImpl::Literal(y)) =
120 func_call.clone().decompose_as_binary()
121 else {
122 return func_call.into();
123 };
124
125 if y.return_type() != DataType::Varchar {
126 return func_call.into();
127 }
128
129 let data = y.get_data();
130 let Some(ScalarImpl::Utf8(data)) = data else {
131 return func_call.into();
132 };
133
134 let bytes = data.as_bytes();
135 let len = bytes.len();
136
137 let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
138 Self::cal_index_and_unescape(bytes);
139
140 let idx = match (char_wildcard_idx, str_wildcard_idx) {
141 (Some(a), Some(b)) => min(a, b),
142 (Some(idx), None) => idx,
143 (None, Some(idx)) => idx,
144 (None, None) => {
145 let Ok(unescaped_y) = String::from_utf8(unescaped_bytes) else {
146 return func_call.into();
148 };
149 let inputs = vec![
150 ExprImpl::InputRef(x),
151 ExprImpl::literal_varchar(unescaped_y),
152 ];
153 let func_call = FunctionCall::new_unchecked(ExprType::Equal, inputs, ret);
154 return func_call.into();
155 }
156 };
157
158 if idx == 0 {
159 return func_call.into();
160 }
161
162 let (low, high) = {
163 let low = unescaped_bytes[0..idx].to_owned();
164 if low[idx - 1] == 255 {
165 return func_call.into();
166 }
167 let mut high = low.clone();
168 high[idx - 1] += 1;
169 match (from_utf8(&low), from_utf8(&high)) {
170 (Ok(low), Ok(high)) => (low.to_owned(), high.to_owned()),
171 _ => {
172 return func_call.into();
173 }
174 }
175 };
176
177 let between = FunctionCall::new_unchecked(
178 ExprType::And,
179 vec![
180 FunctionCall::new_unchecked(
181 ExprType::GreaterThanOrEqual,
182 vec![
183 ExprImpl::InputRef(x.clone()),
184 ExprImpl::Literal(
185 Literal::new(Some(ScalarImpl::Utf8(low.into())), DataType::Varchar)
186 .into(),
187 ),
188 ],
189 DataType::Boolean,
190 )
191 .into(),
192 FunctionCall::new_unchecked(
193 ExprType::LessThan,
194 vec![
195 ExprImpl::InputRef(x),
196 ExprImpl::Literal(
197 Literal::new(Some(ScalarImpl::Utf8(high.into())), DataType::Varchar)
198 .into(),
199 ),
200 ],
201 DataType::Boolean,
202 )
203 .into(),
204 ],
205 DataType::Boolean,
206 );
207
208 if idx == len - 1 {
209 between.into()
210 } else {
211 FunctionCall::new_unchecked(
212 ExprType::And,
213 vec![between.into(), func_call.into()],
214 DataType::Boolean,
215 )
216 .into()
217 }
218 }
219}
220
221impl RewriteLikeExprRule {
222 pub fn create() -> BoxedRule {
223 Box::new(RewriteLikeExprRule {})
224 }
225}
226
#[cfg(test)]
mod tests {
    #[test]
    fn test_cal_index_and_unescape() {
        // Each case: (pattern, (first '_' index, first '%' index, unescaped pattern)).
        // Indices refer to positions in the unescaped pattern.
        #[expect(clippy::type_complexity, reason = "in testcase")]
        let testcases: [(&str, (Option<usize>, Option<usize>, &str)); 12] = [
            ("testname", (None, None, "testname")),
            ("test_name", (Some(4), None, "test_name")),
            ("test_name_2", (Some(4), None, "test_name_2")),
            ("test%name", (None, Some(4), "test%name")),
            (r"test\_name", (None, None, "test_name")),
            (r"test\_name_2", (Some(9), None, "test_name_2")),
            (r"test\\_name_2", (Some(5), None, r"test\_name_2")),
            // Empty pattern: no wildcards, nothing to unescape.
            ("", (None, None, "")),
            // Wildcards are tracked independently, whichever comes first.
            ("%_", (Some(1), Some(0), "%_")),
            ("a%b_c%", (Some(3), Some(1), "a%b_c%")),
            // An escaped `%` is literal; only the unescaped one is reported.
            (r"ab\%%", (None, Some(3), "ab%%")),
            // A lone escaped backslash unescapes to a single backslash.
            (r"\\", (None, None, r"\")),
        ];

        for (pattern, (c, s, ub)) in testcases {
            let input = pattern.as_bytes();
            let (char_wildcard_idx, str_wildcard_idx, unescaped_bytes) =
                super::LikeExprRewriter::cal_index_and_unescape(input);
            assert_eq!(char_wildcard_idx, c, "pattern {pattern:?}");
            assert_eq!(str_wildcard_idx, s, "pattern {pattern:?}");
            assert_eq!(&String::from_utf8(unescaped_bytes).unwrap(), ub, "pattern {pattern:?}");
        }
    }

    /// A pattern ending in an unpaired escape is a precondition violation for
    /// `cal_index_and_unescape`; callers are expected to filter it out first.
    #[test]
    #[should_panic]
    fn test_cal_index_and_unescape_trailing_escape_panics() {
        let _ = super::LikeExprRewriter::cal_index_and_unescape(br"abc\");
    }
}