1use itertools::Itertools;
16use risingwave_pb::expr::expr_node::{self, RexNode};
17use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
18use risingwave_pb::plan_common::PbColumnDesc;
19use risingwave_sqlparser::ast::{
20 Array, CdcTableInfo, CreateSink, CreateSinkStatement, CreateSourceStatement,
21 CreateSubscriptionStatement, Distinct, Expr, Function, FunctionArg, FunctionArgExpr,
22 FunctionArgList, Ident, ObjectName, Query, SelectItem, SetExpr, Statement, TableAlias,
23 TableFactor, TableWithJoins, Window,
24};
25use risingwave_sqlparser::parser::Parser;
26
27pub fn alter_relation_rename(definition: &str, new_name: &str) -> String {
32 if definition.is_empty() {
35 tracing::warn!("found empty definition when renaming relation, ignored.");
36 return definition.into();
37 }
38 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
39 let mut stmt = ast
40 .into_iter()
41 .exactly_one()
42 .expect("should contains only one statement");
43
44 match &mut stmt {
45 Statement::CreateTable { name, .. }
46 | Statement::CreateView { name, .. }
47 | Statement::CreateIndex { name, .. }
48 | Statement::CreateSource {
49 stmt: CreateSourceStatement {
50 source_name: name, ..
51 },
52 }
53 | Statement::CreateSubscription {
54 stmt:
55 CreateSubscriptionStatement {
56 subscription_name: name,
57 ..
58 },
59 }
60 | Statement::CreateSink {
61 stmt: CreateSinkStatement {
62 sink_name: name, ..
63 },
64 } => replace_table_name(name, new_name),
65 _ => unreachable!(),
66 };
67
68 stmt.to_string()
69}
70
71pub fn alter_relation_rename_refs(definition: &str, from: &str, to: &str) -> String {
74 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
75 let mut stmt = ast
76 .into_iter()
77 .exactly_one()
78 .expect("should contains only one statement");
79
80 match &mut stmt {
81 Statement::CreateTable {
82 query: Some(query), ..
83 }
84 | Statement::CreateView { query, .. }
85 | Statement::Query(query) | Statement::CreateSink {
87 stmt:
88 CreateSinkStatement {
89 sink_from: CreateSink::AsQuery(query),
90 into_table_name: None,
91 ..
92 },
93 } => {
94 QueryRewriter::rewrite_query(query, from, to);
95 }
96 Statement::CreateIndex { table_name, .. }
97 | Statement::CreateSink {
98 stmt:
99 CreateSinkStatement {
100 sink_from: CreateSink::From(table_name),
101 into_table_name: None,
102 ..
103 },
104 }| Statement::CreateSubscription {
105 stmt:
106 CreateSubscriptionStatement {
107 subscription_from: table_name,
108 ..
109 },
110 } | Statement::CreateTable {
111 cdc_table_info:
112 Some(CdcTableInfo {
113 source_name: table_name,
114 ..
115 }),
116 ..
117 } => replace_table_name(table_name, to),
118 Statement::CreateSink {
119 stmt: CreateSinkStatement {
120 sink_from,
121 into_table_name: Some(table_name),
122 ..
123 }
124 } => {
125 let idx = table_name.0.len() - 1;
126 if table_name.0[idx].real_value() == from {
127 table_name.0[idx] = Ident::from_real_value(to);
128 } else {
129 match sink_from {
130 CreateSink::From(table_name) => replace_table_name(table_name, to),
131 CreateSink::AsQuery(query) => QueryRewriter::rewrite_query(query, from, to),
132 }
133 }
134 }
135 _ => unreachable!(),
136 };
137 stmt.to_string()
138}
139
140fn replace_table_name(table_name: &mut ObjectName, to: &str) {
143 let idx = table_name.0.len() - 1;
144 table_name.0[idx] = Ident::from_real_value(to);
145}
146
147struct QueryRewriter<'a> {
150 from: &'a str,
151 to: &'a str,
152}
153
154impl QueryRewriter<'_> {
155 fn rewrite_query(query: &mut Query, from: &str, to: &str) {
156 let rewriter = QueryRewriter { from, to };
157 rewriter.visit_query(query)
158 }
159
160 fn visit_query(&self, query: &mut Query) {
162 if let Some(with) = &mut query.with {
163 for cte_table in &mut with.cte_tables {
164 match &mut cte_table.cte_inner {
165 risingwave_sqlparser::ast::CteInner::Query(query) => self.visit_query(query),
166 risingwave_sqlparser::ast::CteInner::ChangeLog(name) => {
167 let idx = name.0.len() - 1;
168 if name.0[idx].real_value() == self.from {
169 replace_table_name(name, self.to);
170 }
171 }
172 }
173 }
174 }
175 self.visit_set_expr(&mut query.body);
176 for expr in &mut query.order_by {
177 self.visit_expr(&mut expr.expr);
178 }
179 }
180
181 fn visit_table_factor(&self, table_factor: &mut TableFactor) {
193 match table_factor {
194 TableFactor::Table { name, alias, .. } => {
195 let idx = name.0.len() - 1;
196 if name.0[idx].real_value() == self.from {
197 if alias.is_none() {
198 *alias = Some(TableAlias {
199 name: Ident::from_real_value(self.from),
200 columns: vec![],
201 });
202 }
203 name.0[idx] = Ident::from_real_value(self.to);
204 }
205 }
206 TableFactor::Derived { subquery, .. } => self.visit_query(subquery),
207 TableFactor::TableFunction { args, .. } => {
208 for arg in args {
209 self.visit_function_arg(arg);
210 }
211 }
212 TableFactor::NestedJoin(table_with_joins) => {
213 self.visit_table_with_joins(table_with_joins);
214 }
215 }
216 }
217
218 fn visit_table_with_joins(&self, table_with_joins: &mut TableWithJoins) {
220 self.visit_table_factor(&mut table_with_joins.relation);
221 for join in &mut table_with_joins.joins {
222 self.visit_table_factor(&mut join.relation);
223 }
224 }
225
226 fn visit_set_expr(&self, set_expr: &mut SetExpr) {
228 match set_expr {
229 SetExpr::Select(select) => {
230 if let Distinct::DistinctOn(exprs) = &mut select.distinct {
231 for expr in exprs {
232 self.visit_expr(expr);
233 }
234 }
235 for select_item in &mut select.projection {
236 self.visit_select_item(select_item);
237 }
238 for from_item in &mut select.from {
239 self.visit_table_with_joins(from_item);
240 }
241 if let Some(where_clause) = &mut select.selection {
242 self.visit_expr(where_clause);
243 }
244 for expr in &mut select.group_by {
245 self.visit_expr(expr);
246 }
247 if let Some(having) = &mut select.having {
248 self.visit_expr(having);
249 }
250 for named_window in &mut select.window {
251 for expr in &mut named_window.window_spec.partition_by {
252 self.visit_expr(expr);
253 }
254 for expr in &mut named_window.window_spec.order_by {
255 self.visit_expr(&mut expr.expr);
256 }
257 }
258 }
259 SetExpr::Query(query) => self.visit_query(query),
260 SetExpr::SetOperation { left, right, .. } => {
261 self.visit_set_expr(left);
262 self.visit_set_expr(right);
263 }
264 SetExpr::Values(_) => {}
265 }
266 }
267
268 fn visit_function_arg(&self, function_arg: &mut FunctionArg) {
270 match function_arg {
271 FunctionArg::Unnamed(arg) | FunctionArg::Named { arg, .. } => match arg {
272 FunctionArgExpr::Expr(expr) | FunctionArgExpr::ExprQualifiedWildcard(expr, _) => {
273 self.visit_expr(expr)
274 }
275 FunctionArgExpr::QualifiedWildcard(_, None) | FunctionArgExpr::Wildcard(None) => {}
276 FunctionArgExpr::QualifiedWildcard(_, Some(exprs))
277 | FunctionArgExpr::Wildcard(Some(exprs)) => {
278 for expr in exprs {
279 self.visit_expr(expr);
280 }
281 }
282 FunctionArgExpr::SecretRef(_) => {}
283 },
284 }
285 }
286
287 fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
288 for arg in &mut arg_list.args {
289 self.visit_function_arg(arg);
290 }
291 for expr in &mut arg_list.order_by {
292 self.visit_expr(&mut expr.expr)
293 }
294 }
295
296 fn visit_function(&self, function: &mut Function) {
298 self.visit_function_arg_list(&mut function.arg_list);
299 if let Some(over) = &mut function.over {
300 match over {
301 Window::Spec(window) => {
302 for expr in &mut window.partition_by {
303 self.visit_expr(expr);
304 }
305 for expr in &mut window.order_by {
306 self.visit_expr(&mut expr.expr);
307 }
308 }
309 Window::Name(_) => {
310 }
312 }
313 }
314 }
315
316 fn visit_expr(&self, expr: &mut Expr) {
318 match expr {
319 Expr::FieldIdentifier(expr, ..)
320 | Expr::IsNull(expr)
321 | Expr::IsNotNull(expr)
322 | Expr::IsTrue(expr)
323 | Expr::IsNotTrue(expr)
324 | Expr::IsFalse(expr)
325 | Expr::IsNotFalse(expr)
326 | Expr::IsUnknown(expr)
327 | Expr::IsNotUnknown(expr)
328 | Expr::IsJson { expr, .. }
329 | Expr::InList { expr, .. }
330 | Expr::SomeOp(expr)
331 | Expr::AllOp(expr)
332 | Expr::UnaryOp { expr, .. }
333 | Expr::Cast { expr, .. }
334 | Expr::TryCast { expr, .. }
335 | Expr::AtTimeZone {
336 timestamp: expr, ..
337 }
338 | Expr::Extract { expr, .. }
339 | Expr::Substring { expr, .. }
340 | Expr::Overlay { expr, .. }
341 | Expr::Trim { expr, .. }
342 | Expr::Nested(expr)
343 | Expr::Index { obj: expr, .. }
344 | Expr::ArrayRangeIndex { obj: expr, .. } => self.visit_expr(expr),
345
346 Expr::Position { substring, string } => {
347 self.visit_expr(substring);
348 self.visit_expr(string);
349 }
350
351 Expr::InSubquery { expr, subquery, .. } => {
352 self.visit_expr(expr);
353 self.visit_query(subquery);
354 }
355 Expr::Between {
356 expr, low, high, ..
357 } => {
358 self.visit_expr(expr);
359 self.visit_expr(low);
360 self.visit_expr(high);
361 }
362 Expr::Like {
363 expr, pattern: pat, ..
364 } => {
365 self.visit_expr(expr);
366 self.visit_expr(pat);
367 }
368 Expr::ILike {
369 expr, pattern: pat, ..
370 } => {
371 self.visit_expr(expr);
372 self.visit_expr(pat);
373 }
374 Expr::SimilarTo {
375 expr, pattern: pat, ..
376 } => {
377 self.visit_expr(expr);
378 self.visit_expr(pat);
379 }
380
381 Expr::IsDistinctFrom(expr1, expr2)
382 | Expr::IsNotDistinctFrom(expr1, expr2)
383 | Expr::BinaryOp {
384 left: expr1,
385 right: expr2,
386 ..
387 } => {
388 self.visit_expr(expr1);
389 self.visit_expr(expr2);
390 }
391 Expr::Function(function) => self.visit_function(function),
392 Expr::Exists(query) | Expr::Subquery(query) | Expr::ArraySubquery(query) => {
393 self.visit_query(query)
394 }
395
396 Expr::GroupingSets(exprs_vec) | Expr::Cube(exprs_vec) | Expr::Rollup(exprs_vec) => {
397 for exprs in exprs_vec {
398 for expr in exprs {
399 self.visit_expr(expr);
400 }
401 }
402 }
403
404 Expr::Row(exprs) | Expr::Array(Array { elem: exprs, .. }) => {
405 for expr in exprs {
406 self.visit_expr(expr);
407 }
408 }
409 Expr::Map { entries } => {
410 for (key, value) in entries {
411 self.visit_expr(key);
412 self.visit_expr(value);
413 }
414 }
415
416 Expr::LambdaFunction { body, args: _ } => self.visit_expr(body),
417
418 Expr::Identifier(_)
420 | Expr::CompoundIdentifier(_)
421 | Expr::Collate { .. }
422 | Expr::Value(_)
423 | Expr::Parameter { .. }
424 | Expr::TypedString { .. }
425 | Expr::Case { .. } => {}
426 }
427 }
428
429 fn visit_select_item(&self, select_item: &mut SelectItem) {
431 match select_item {
432 SelectItem::UnnamedExpr(expr)
433 | SelectItem::ExprQualifiedWildcard(expr, _)
434 | SelectItem::ExprWithAlias { expr, .. } => self.visit_expr(expr),
435 SelectItem::QualifiedWildcard(_, None) | SelectItem::Wildcard(None) => {}
436 SelectItem::QualifiedWildcard(_, Some(exprs)) | SelectItem::Wildcard(Some(exprs)) => {
437 for expr in exprs {
438 self.visit_expr(expr);
439 }
440 }
441 }
442 }
443}
444
445pub struct IndexItemRewriter {
448 pub original_columns: Vec<PbColumnDesc>,
449 pub new_columns: Vec<PbColumnDesc>,
450}
451
452impl IndexItemRewriter {
453 pub fn rewrite_expr(&self, expr: &mut ExprNode) {
454 let rex_node = expr.rex_node.as_mut().unwrap();
455 match rex_node {
456 RexNode::InputRef(idx) => {
457 let old_idx = *idx as usize;
458 let original_column = &self.original_columns[old_idx];
459 let (new_idx, new_column) = self
460 .new_columns
461 .iter()
462 .find_position(|c| c.column_id == original_column.column_id)
463 .expect("should already checked index referencing column still exists");
464 *idx = new_idx as u32;
465
466 if new_column.column_type != original_column.column_type {
470 let old_type = original_column.column_type.clone().unwrap();
471 let new_type = new_column.column_type.clone().unwrap();
472
473 assert_eq!(&old_type, expr.return_type.as_ref().unwrap());
474 expr.return_type = Some(new_type); let new_expr_node = ExprNode {
477 function_type: expr_node::Type::CompositeCast as _,
478 return_type: Some(old_type),
479 rex_node: RexNode::FuncCall(FunctionCall {
480 children: vec![expr.clone()],
481 })
482 .into(),
483 };
484
485 *expr = new_expr_node;
486 }
487 }
488 RexNode::Constant(_) => {}
489 RexNode::Udf(udf) => self.rewrite_udf(udf),
490 RexNode::FuncCall(function_call) => self.rewrite_function_call(function_call),
491 RexNode::Now(_) | RexNode::SecretRef(_) => {}
492 }
493 }
494
495 fn rewrite_udf(&self, udf: &mut UserDefinedFunction) {
496 udf.children
497 .iter_mut()
498 .for_each(|expr| self.rewrite_expr(expr));
499 }
500
501 fn rewrite_function_call(&self, function_call: &mut FunctionCall) {
502 function_call
503 .children
504 .iter_mut()
505 .for_each(|expr| self.rewrite_expr(expr));
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewrites references `from` -> `to` in `definition` and compares with `expected`.
    fn check_rename_refs(definition: &str, from: &str, to: &str, expected: &str) {
        let actual = alter_relation_rename_refs(definition, from, to);
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_alter_table_rename() {
        let actual = alter_relation_rename("CREATE TABLE foo (a int, b int)", "bar");
        assert_eq!("CREATE TABLE bar (a INT, b INT)", actual);
    }

    #[test]
    fn test_rename_index_refs() {
        check_rename_refs(
            "CREATE INDEX idx1 ON foo(v1 DESC, v2)",
            "foo",
            "bar",
            "CREATE INDEX idx1 ON bar(v1 DESC, v2)",
        );
    }

    #[test]
    fn test_rename_sink_refs() {
        check_rename_refs(
            "CREATE SINK sink_t FROM foo WITH (connector = 'kafka', format = 'append_only')",
            "foo",
            "bar",
            "CREATE SINK sink_t FROM bar WITH (connector = 'kafka', format = 'append_only')",
        );
    }

    #[test]
    fn test_rename_with_alias_refs() {
        // (definition, expected) pairs, all renaming `foo` to `bar`.
        let cases = [
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM foo",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM bar AS foo",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM bar AS foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM foo AS bar WHERE bar.v1 = 1",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM bar AS bar WHERE bar.v1 = 1",
            ),
        ];
        for (definition, expected) in cases {
            check_rename_refs(definition, "foo", "bar", expected);
        }
    }

    #[test]
    fn test_rename_with_complex_funcs() {
        let definition = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                          agg1(\
                            foo.v1, func2(foo.v2) \
                            ORDER BY \
                            (SELECT foo.v3 FROM foo), \
                            (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM foo) ORDER BY (SELECT foo.v6 FROM foo)) FROM foo)\
                          ) \
                          FROM foo";
        let expected = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                        agg1(\
                          foo.v1, func2(foo.v2) \
                          ORDER BY \
                          (SELECT foo.v3 FROM bar AS foo), \
                          (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM bar AS foo) ORDER BY (SELECT foo.v6 FROM bar AS foo)) FROM bar AS foo)\
                        ) \
                        FROM bar AS foo";
        check_rename_refs(definition, "foo", "bar", expected);
    }
}