1use itertools::Itertools;
16use risingwave_common::util::column_index_mapping::ColIndexMapping;
17use risingwave_pb::expr::expr_node::RexNode;
18use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
19use risingwave_sqlparser::ast::{
20 Array, CreateSink, CreateSinkStatement, CreateSourceStatement, CreateSubscriptionStatement,
21 Distinct, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, ObjectName,
22 Query, SelectItem, SetExpr, Statement, TableAlias, TableFactor, TableWithJoins,
23};
24use risingwave_sqlparser::parser::Parser;
25
26pub fn alter_relation_rename(definition: &str, new_name: &str) -> String {
31 if definition.is_empty() {
34 tracing::warn!("found empty definition when renaming relation, ignored.");
35 return definition.into();
36 }
37 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
38 let mut stmt = ast
39 .into_iter()
40 .exactly_one()
41 .expect("should contains only one statement");
42
43 match &mut stmt {
44 Statement::CreateTable { name, .. }
45 | Statement::CreateView { name, .. }
46 | Statement::CreateIndex { name, .. }
47 | Statement::CreateSource {
48 stmt: CreateSourceStatement {
49 source_name: name, ..
50 },
51 }
52 | Statement::CreateSubscription {
53 stmt:
54 CreateSubscriptionStatement {
55 subscription_name: name,
56 ..
57 },
58 }
59 | Statement::CreateSink {
60 stmt: CreateSinkStatement {
61 sink_name: name, ..
62 },
63 } => replace_table_name(name, new_name),
64 _ => unreachable!(),
65 };
66
67 stmt.to_string()
68}
69
70pub fn alter_relation_rename_refs(definition: &str, from: &str, to: &str) -> String {
73 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
74 let mut stmt = ast
75 .into_iter()
76 .exactly_one()
77 .expect("should contains only one statement");
78
79 match &mut stmt {
80 Statement::CreateTable {
81 query: Some(query), ..
82 }
83 | Statement::CreateView { query, .. }
84 | Statement::Query(query) | Statement::CreateSink {
86 stmt:
87 CreateSinkStatement {
88 sink_from: CreateSink::AsQuery(query),
89 into_table_name: None,
90 ..
91 },
92 } => {
93 QueryRewriter::rewrite_query(query, from, to);
94 }
95 Statement::CreateIndex { table_name, .. }
96 | Statement::CreateSink {
97 stmt:
98 CreateSinkStatement {
99 sink_from: CreateSink::From(table_name),
100 into_table_name: None,
101 ..
102 },
103 }| Statement::CreateSubscription {
104 stmt:
105 CreateSubscriptionStatement {
106 subscription_from: table_name,
107 ..
108 },
109 } => replace_table_name(table_name, to),
110 Statement::CreateSink {
111 stmt: CreateSinkStatement {
112 sink_from,
113 into_table_name: Some(table_name),
114 ..
115 }
116 } => {
117 let idx = table_name.0.len() - 1;
118 if table_name.0[idx].real_value() == from {
119 table_name.0[idx] = Ident::new_unchecked(to);
120 } else {
121 match sink_from {
122 CreateSink::From(table_name) => replace_table_name(table_name, to),
123 CreateSink::AsQuery(query) => QueryRewriter::rewrite_query(query, from, to),
124 }
125 }
126 }
127 _ => unreachable!(),
128 };
129 stmt.to_string()
130}
131
132fn replace_table_name(table_name: &mut ObjectName, to: &str) {
135 let idx = table_name.0.len() - 1;
136 table_name.0[idx] = Ident::new_unchecked(to);
137}
138
/// Visitor that walks a parsed query and redirects table references whose last identifier is
/// `from` so that they refer to `to` instead.
struct QueryRewriter<'a> {
    /// Name to look for; compared against `real_value()` of the last identifier of a table name.
    from: &'a str,
    /// Name written in place of `from` wherever a match is found.
    to: &'a str,
}
145
impl QueryRewriter<'_> {
    /// Entry point: rewrites, in place, the table references in `query` that are named `from`
    /// so that they point at `to`.
    fn rewrite_query(query: &mut Query, from: &str, to: &str) {
        let rewriter = QueryRewriter { from, to };
        rewriter.visit_query(query)
    }

    /// Visits a whole query: its CTEs (if any), then its body, then its `ORDER BY` expressions.
    fn visit_query(&self, query: &mut Query) {
        if let Some(with) = &mut query.with {
            for cte_table in &mut with.cte_tables {
                match &mut cte_table.cte_inner {
                    risingwave_sqlparser::ast::CteInner::Query(query) => self.visit_query(query),
                    risingwave_sqlparser::ast::CteInner::ChangeLog(name) => {
                        // Only the last identifier is compared and replaced.
                        // NOTE(review): this arm writes a double-quoted identifier, whereas the
                        // other rewrites in this impl use `Ident::new_unchecked` — confirm the
                        // difference is intended.
                        let idx = name.0.len() - 1;
                        if name.0[idx].real_value() == self.from {
                            name.0[idx] = Ident::with_quote_unchecked('"', self.to);
                        }
                    }
                }
            }
        }
        self.visit_set_expr(&mut query.body);
        for expr in &mut query.order_by {
            self.visit_expr(&mut expr.expr);
        }
    }

    /// Visits one `FROM` item and rewrites it when it refers to `from`.
    ///
    /// A plain table whose last identifier equals `from` is rewritten as follows:
    /// 1. `FROM a` becomes `FROM new_a AS a` (the old name is installed as the alias);
    /// 2. `FROM a AS b` becomes `FROM new_a AS b` (an existing alias is kept).
    ///
    /// Because the old name (or the existing alias) stays visible as the alias, this visitor
    /// never edits identifiers such as `a.column` inside expressions.
    fn visit_table_factor(&self, table_factor: &mut TableFactor) {
        match table_factor {
            TableFactor::Table { name, alias, .. } => {
                let idx = name.0.len() - 1;
                if name.0[idx].real_value() == self.from {
                    // Keep the old name addressable inside the query when no alias exists yet.
                    if alias.is_none() {
                        *alias = Some(TableAlias {
                            name: Ident::new_unchecked(self.from),
                            columns: vec![],
                        });
                    }
                    name.0[idx] = Ident::new_unchecked(self.to);
                }
            }
            TableFactor::Derived { subquery, .. } => self.visit_query(subquery),
            TableFactor::TableFunction { args, .. } => {
                for arg in args {
                    self.visit_function_arg(arg);
                }
            }
            TableFactor::NestedJoin(table_with_joins) => {
                self.visit_table_with_joins(table_with_joins);
            }
        }
    }

    /// Visits the leading relation of a `FROM` item and the relation of each of its joins.
    ///
    /// NOTE(review): only `relation` of each join is visited; whatever else a join carries
    /// (e.g. a join condition) is not walked here — confirm that is acceptable.
    fn visit_table_with_joins(&self, table_with_joins: &mut TableWithJoins) {
        self.visit_table_factor(&mut table_with_joins.relation);
        for join in &mut table_with_joins.joins {
            self.visit_table_factor(&mut join.relation);
        }
    }

    /// Visits a set expression: a `SELECT`, a parenthesized query, or a set operation.
    ///
    /// For a `SELECT` the `DISTINCT ON` expressions, projection, `FROM` items, `WHERE`,
    /// `GROUP BY` and `HAVING` are visited, in that order. `VALUES` is left untouched.
    fn visit_set_expr(&self, set_expr: &mut SetExpr) {
        match set_expr {
            SetExpr::Select(select) => {
                if let Distinct::DistinctOn(exprs) = &mut select.distinct {
                    for expr in exprs {
                        self.visit_expr(expr);
                    }
                }
                for select_item in &mut select.projection {
                    self.visit_select_item(select_item);
                }
                for from_item in &mut select.from {
                    self.visit_table_with_joins(from_item);
                }
                if let Some(where_clause) = &mut select.selection {
                    self.visit_expr(where_clause);
                }
                for expr in &mut select.group_by {
                    self.visit_expr(expr);
                }
                if let Some(having) = &mut select.having {
                    self.visit_expr(having);
                }
            }
            SetExpr::Query(query) => self.visit_query(query),
            SetExpr::SetOperation { left, right, .. } => {
                self.visit_set_expr(left);
                self.visit_set_expr(right);
            }
            SetExpr::Values(_) => {}
        }
    }

    /// Visits the expression(s) carried by a single function argument, named or unnamed.
    fn visit_function_arg(&self, function_arg: &mut FunctionArg) {
        match function_arg {
            FunctionArg::Unnamed(arg) | FunctionArg::Named { arg, .. } => match arg {
                FunctionArgExpr::Expr(expr) | FunctionArgExpr::ExprQualifiedWildcard(expr, _) => {
                    self.visit_expr(expr)
                }
                // Wildcards without attached expressions have nothing to visit.
                FunctionArgExpr::QualifiedWildcard(_, None) | FunctionArgExpr::Wildcard(None) => {}
                FunctionArgExpr::QualifiedWildcard(_, Some(exprs))
                | FunctionArgExpr::Wildcard(Some(exprs)) => {
                    for expr in exprs {
                        self.visit_expr(expr);
                    }
                }
            },
        }
    }

    /// Visits every argument of a function call and the expressions of its inner `ORDER BY`.
    fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
        for arg in &mut arg_list.args {
            self.visit_function_arg(arg);
        }
        for expr in &mut arg_list.order_by {
            self.visit_expr(&mut expr.expr)
        }
    }

    /// Visits a function call: its argument list and, when an `OVER` clause is present, the
    /// `PARTITION BY` and `ORDER BY` expressions of that clause.
    fn visit_function(&self, function: &mut Function) {
        self.visit_function_arg_list(&mut function.arg_list);
        if let Some(over) = &mut function.over {
            for expr in &mut over.partition_by {
                self.visit_expr(expr);
            }
            for expr in &mut over.order_by {
                self.visit_expr(&mut expr.expr);
            }
        }
    }

    /// Recursively visits an expression so that any subquery nested in it gets rewritten.
    ///
    /// Expressions themselves are never modified; they are walked only to reach queries
    /// (`InSubquery`, `Exists`, `Subquery`, `ArraySubquery`) and function calls.
    fn visit_expr(&self, expr: &mut Expr) {
        match expr {
            // Variants for which a single inner expression is visited. Fields hidden behind
            // `..` (for example the list of `InList`) are not walked.
            Expr::FieldIdentifier(expr, ..)
            | Expr::IsNull(expr)
            | Expr::IsNotNull(expr)
            | Expr::IsTrue(expr)
            | Expr::IsNotTrue(expr)
            | Expr::IsFalse(expr)
            | Expr::IsNotFalse(expr)
            | Expr::IsUnknown(expr)
            | Expr::IsNotUnknown(expr)
            | Expr::IsJson { expr, .. }
            | Expr::InList { expr, .. }
            | Expr::SomeOp(expr)
            | Expr::AllOp(expr)
            | Expr::UnaryOp { expr, .. }
            | Expr::Cast { expr, .. }
            | Expr::TryCast { expr, .. }
            | Expr::AtTimeZone {
                timestamp: expr, ..
            }
            | Expr::Extract { expr, .. }
            | Expr::Substring { expr, .. }
            | Expr::Overlay { expr, .. }
            | Expr::Trim { expr, .. }
            | Expr::Nested(expr)
            | Expr::Index { obj: expr, .. }
            | Expr::ArrayRangeIndex { obj: expr, .. } => self.visit_expr(expr),

            Expr::Position { substring, string } => {
                self.visit_expr(substring);
                self.visit_expr(string);
            }

            Expr::InSubquery { expr, subquery, .. } => {
                self.visit_expr(expr);
                self.visit_query(subquery);
            }
            Expr::Between {
                expr, low, high, ..
            } => {
                self.visit_expr(expr);
                self.visit_expr(low);
                self.visit_expr(high);
            }
            Expr::Like {
                expr, pattern: pat, ..
            } => {
                self.visit_expr(expr);
                self.visit_expr(pat);
            }
            Expr::ILike {
                expr, pattern: pat, ..
            } => {
                self.visit_expr(expr);
                self.visit_expr(pat);
            }
            Expr::SimilarTo {
                expr, pattern: pat, ..
            } => {
                self.visit_expr(expr);
                self.visit_expr(pat);
            }

            // Two-operand variants: both sides are visited.
            Expr::IsDistinctFrom(expr1, expr2)
            | Expr::IsNotDistinctFrom(expr1, expr2)
            | Expr::BinaryOp {
                left: expr1,
                right: expr2,
                ..
            } => {
                self.visit_expr(expr1);
                self.visit_expr(expr2);
            }
            Expr::Function(function) => self.visit_function(function),
            Expr::Exists(query) | Expr::Subquery(query) | Expr::ArraySubquery(query) => {
                self.visit_query(query)
            }

            Expr::GroupingSets(exprs_vec) | Expr::Cube(exprs_vec) | Expr::Rollup(exprs_vec) => {
                for exprs in exprs_vec {
                    for expr in exprs {
                        self.visit_expr(expr);
                    }
                }
            }

            Expr::Row(exprs) | Expr::Array(Array { elem: exprs, .. }) => {
                for expr in exprs {
                    self.visit_expr(expr);
                }
            }
            Expr::Map { entries } => {
                for (key, value) in entries {
                    self.visit_expr(key);
                    self.visit_expr(value);
                }
            }

            Expr::LambdaFunction { body, args: _ } => self.visit_expr(body),

            // Variants this visitor does not descend into.
            // NOTE(review): `Case` and `Collate` are skipped wholesale; if they can contain
            // subqueries, references inside them would not be rewritten — confirm.
            Expr::Identifier(_)
            | Expr::CompoundIdentifier(_)
            | Expr::Collate { .. }
            | Expr::Value(_)
            | Expr::Parameter { .. }
            | Expr::TypedString { .. }
            | Expr::Case { .. } => {}
        }
    }

    /// Visits the expression(s) carried by one item of a `SELECT` projection.
    fn visit_select_item(&self, select_item: &mut SelectItem) {
        match select_item {
            SelectItem::UnnamedExpr(expr)
            | SelectItem::ExprQualifiedWildcard(expr, _)
            | SelectItem::ExprWithAlias { expr, .. } => self.visit_expr(expr),
            // Wildcards without attached expressions have nothing to visit.
            SelectItem::QualifiedWildcard(_, None) | SelectItem::Wildcard(None) => {}
            SelectItem::QualifiedWildcard(_, Some(exprs)) | SelectItem::Wildcard(Some(exprs)) => {
                for expr in exprs {
                    self.visit_expr(expr);
                }
            }
        }
    }
}
420
/// Rewrites the input-column references of protobuf expression trees (`ExprNode`) through a
/// column index mapping.
pub struct ReplaceTableExprRewriter {
    /// Mapping applied to the column index of every `RexNode::InputRef` that is visited.
    pub table_col_index_mapping: ColIndexMapping,
}
424
425impl ReplaceTableExprRewriter {
426 pub fn rewrite_expr(&self, expr: &mut ExprNode) {
427 let rex_node = expr.rex_node.as_mut().unwrap();
428 match rex_node {
429 RexNode::InputRef(input_col_idx) => {
430 *input_col_idx = self.table_col_index_mapping.map(*input_col_idx as usize) as u32
431 }
432 RexNode::Constant(_) => {}
433 RexNode::Udf(udf) => self.rewrite_udf(udf),
434 RexNode::FuncCall(function_call) => self.rewrite_function_call(function_call),
435 RexNode::Now(_) => {}
436 }
437 }
438
439 fn rewrite_udf(&self, udf: &mut UserDefinedFunction) {
440 udf.children
441 .iter_mut()
442 .for_each(|expr| self.rewrite_expr(expr));
443 }
444
445 fn rewrite_function_call(&self, function_call: &mut FunctionCall) {
446 function_call
447 .children
448 .iter_mut()
449 .for_each(|expr| self.rewrite_expr(expr));
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Renaming a table rewrites the name inside its own `CREATE TABLE` statement.
    #[test]
    fn test_alter_table_rename() {
        let renamed = alter_relation_rename("CREATE TABLE foo (a int, b int)", "bar");
        assert_eq!("CREATE TABLE bar (a INT, b INT)", renamed);
    }

    /// An index definition follows the rename of the table it is built on.
    #[test]
    fn test_rename_index_refs() {
        let rewritten =
            alter_relation_rename_refs("CREATE INDEX idx1 ON foo(v1 DESC, v2)", "foo", "bar");
        assert_eq!("CREATE INDEX idx1 ON bar(v1 DESC, v2)", rewritten);
    }

    /// A `CREATE SINK ... FROM <table>` definition follows the rename of its source table.
    #[test]
    fn test_rename_sink_refs() {
        let rewritten = alter_relation_rename_refs(
            "CREATE SINK sink_t FROM foo WITH (connector = 'kafka', format = 'append_only')",
            "foo",
            "bar",
        );
        assert_eq!(
            "CREATE SINK sink_t FROM bar WITH (connector = 'kafka', format = 'append_only')",
            rewritten
        );
    }

    /// The old name is kept as an alias (or an existing alias is preserved), so column
    /// references in the query stay as they were.
    #[test]
    fn test_rename_with_alias_refs() {
        // (definition, expected result after renaming `foo` to `bar`)
        let cases = [
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM foo",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM bar AS foo",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM bar AS foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM foo AS bar WHERE bar.v1 = 1",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM bar AS bar WHERE bar.v1 = 1",
            ),
        ];

        for (definition, expected) in cases {
            let actual = alter_relation_rename_refs(definition, "foo", "bar");
            assert_eq!(expected, actual);
        }
    }

    /// References nested in aggregate `ORDER BY` clauses, window `PARTITION BY` / `ORDER BY`
    /// clauses and scalar subqueries are all rewritten.
    #[test]
    fn test_rename_with_complex_funcs() {
        let sql_before = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                            agg1(\
                                foo.v1, func2(foo.v2) \
                                ORDER BY \
                                (SELECT foo.v3 FROM foo), \
                                (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM foo) ORDER BY (SELECT foo.v6 FROM foo)) FROM foo)\
                            ) \
                            FROM foo";
        let sql_after = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                            agg1(\
                                foo.v1, func2(foo.v2) \
                                ORDER BY \
                                (SELECT foo.v3 FROM bar AS foo), \
                                (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM bar AS foo) ORDER BY (SELECT foo.v6 FROM bar AS foo)) FROM bar AS foo)\
                            ) \
                            FROM bar AS foo";

        assert_eq!(
            sql_after,
            alter_relation_rename_refs(sql_before, "foo", "bar")
        );
    }
}