1use itertools::Itertools;
16use risingwave_pb::expr::expr_node::{self, RexNode};
17use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
18use risingwave_pb::plan_common::PbColumnDesc;
19use risingwave_sqlparser::ast::{
20 Array, CreateSink, CreateSinkStatement, CreateSourceStatement, CreateSubscriptionStatement,
21 Distinct, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, ObjectName,
22 Query, SelectItem, SetExpr, Statement, TableAlias, TableFactor, TableWithJoins, Window,
23};
24use risingwave_sqlparser::parser::Parser;
25
26pub fn alter_relation_rename(definition: &str, new_name: &str) -> String {
31 if definition.is_empty() {
34 tracing::warn!("found empty definition when renaming relation, ignored.");
35 return definition.into();
36 }
37 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
38 let mut stmt = ast
39 .into_iter()
40 .exactly_one()
41 .expect("should contains only one statement");
42
43 match &mut stmt {
44 Statement::CreateTable { name, .. }
45 | Statement::CreateView { name, .. }
46 | Statement::CreateIndex { name, .. }
47 | Statement::CreateSource {
48 stmt: CreateSourceStatement {
49 source_name: name, ..
50 },
51 }
52 | Statement::CreateSubscription {
53 stmt:
54 CreateSubscriptionStatement {
55 subscription_name: name,
56 ..
57 },
58 }
59 | Statement::CreateSink {
60 stmt: CreateSinkStatement {
61 sink_name: name, ..
62 },
63 } => replace_table_name(name, new_name),
64 _ => unreachable!(),
65 };
66
67 stmt.to_string()
68}
69
70pub fn alter_relation_rename_refs(definition: &str, from: &str, to: &str) -> String {
73 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
74 let mut stmt = ast
75 .into_iter()
76 .exactly_one()
77 .expect("should contains only one statement");
78
79 match &mut stmt {
80 Statement::CreateTable {
81 query: Some(query), ..
82 }
83 | Statement::CreateView { query, .. }
84 | Statement::Query(query) | Statement::CreateSink {
86 stmt:
87 CreateSinkStatement {
88 sink_from: CreateSink::AsQuery(query),
89 into_table_name: None,
90 ..
91 },
92 } => {
93 QueryRewriter::rewrite_query(query, from, to);
94 }
95 Statement::CreateIndex { table_name, .. }
96 | Statement::CreateSink {
97 stmt:
98 CreateSinkStatement {
99 sink_from: CreateSink::From(table_name),
100 into_table_name: None,
101 ..
102 },
103 }| Statement::CreateSubscription {
104 stmt:
105 CreateSubscriptionStatement {
106 subscription_from: table_name,
107 ..
108 },
109 } => replace_table_name(table_name, to),
110 Statement::CreateSink {
111 stmt: CreateSinkStatement {
112 sink_from,
113 into_table_name: Some(table_name),
114 ..
115 }
116 } => {
117 let idx = table_name.0.len() - 1;
118 if table_name.0[idx].real_value() == from {
119 table_name.0[idx] = Ident::from_real_value(to);
120 } else {
121 match sink_from {
122 CreateSink::From(table_name) => replace_table_name(table_name, to),
123 CreateSink::AsQuery(query) => QueryRewriter::rewrite_query(query, from, to),
124 }
125 }
126 }
127 _ => unreachable!(),
128 };
129 stmt.to_string()
130}
131
132fn replace_table_name(table_name: &mut ObjectName, to: &str) {
135 let idx = table_name.0.len() - 1;
136 table_name.0[idx] = Ident::from_real_value(to);
137}
138
139struct QueryRewriter<'a> {
142 from: &'a str,
143 to: &'a str,
144}
145
146impl QueryRewriter<'_> {
147 fn rewrite_query(query: &mut Query, from: &str, to: &str) {
148 let rewriter = QueryRewriter { from, to };
149 rewriter.visit_query(query)
150 }
151
152 fn visit_query(&self, query: &mut Query) {
154 if let Some(with) = &mut query.with {
155 for cte_table in &mut with.cte_tables {
156 match &mut cte_table.cte_inner {
157 risingwave_sqlparser::ast::CteInner::Query(query) => self.visit_query(query),
158 risingwave_sqlparser::ast::CteInner::ChangeLog(name) => {
159 let idx = name.0.len() - 1;
160 if name.0[idx].real_value() == self.from {
161 replace_table_name(name, self.to);
162 }
163 }
164 }
165 }
166 }
167 self.visit_set_expr(&mut query.body);
168 for expr in &mut query.order_by {
169 self.visit_expr(&mut expr.expr);
170 }
171 }
172
173 fn visit_table_factor(&self, table_factor: &mut TableFactor) {
185 match table_factor {
186 TableFactor::Table { name, alias, .. } => {
187 let idx = name.0.len() - 1;
188 if name.0[idx].real_value() == self.from {
189 if alias.is_none() {
190 *alias = Some(TableAlias {
191 name: Ident::from_real_value(self.from),
192 columns: vec![],
193 });
194 }
195 name.0[idx] = Ident::from_real_value(self.to);
196 }
197 }
198 TableFactor::Derived { subquery, .. } => self.visit_query(subquery),
199 TableFactor::TableFunction { args, .. } => {
200 for arg in args {
201 self.visit_function_arg(arg);
202 }
203 }
204 TableFactor::NestedJoin(table_with_joins) => {
205 self.visit_table_with_joins(table_with_joins);
206 }
207 }
208 }
209
210 fn visit_table_with_joins(&self, table_with_joins: &mut TableWithJoins) {
212 self.visit_table_factor(&mut table_with_joins.relation);
213 for join in &mut table_with_joins.joins {
214 self.visit_table_factor(&mut join.relation);
215 }
216 }
217
218 fn visit_set_expr(&self, set_expr: &mut SetExpr) {
220 match set_expr {
221 SetExpr::Select(select) => {
222 if let Distinct::DistinctOn(exprs) = &mut select.distinct {
223 for expr in exprs {
224 self.visit_expr(expr);
225 }
226 }
227 for select_item in &mut select.projection {
228 self.visit_select_item(select_item);
229 }
230 for from_item in &mut select.from {
231 self.visit_table_with_joins(from_item);
232 }
233 if let Some(where_clause) = &mut select.selection {
234 self.visit_expr(where_clause);
235 }
236 for expr in &mut select.group_by {
237 self.visit_expr(expr);
238 }
239 if let Some(having) = &mut select.having {
240 self.visit_expr(having);
241 }
242 for named_window in &mut select.window {
243 for expr in &mut named_window.window_spec.partition_by {
244 self.visit_expr(expr);
245 }
246 for expr in &mut named_window.window_spec.order_by {
247 self.visit_expr(&mut expr.expr);
248 }
249 }
250 }
251 SetExpr::Query(query) => self.visit_query(query),
252 SetExpr::SetOperation { left, right, .. } => {
253 self.visit_set_expr(left);
254 self.visit_set_expr(right);
255 }
256 SetExpr::Values(_) => {}
257 }
258 }
259
260 fn visit_function_arg(&self, function_arg: &mut FunctionArg) {
262 match function_arg {
263 FunctionArg::Unnamed(arg) | FunctionArg::Named { arg, .. } => match arg {
264 FunctionArgExpr::Expr(expr) | FunctionArgExpr::ExprQualifiedWildcard(expr, _) => {
265 self.visit_expr(expr)
266 }
267 FunctionArgExpr::QualifiedWildcard(_, None) | FunctionArgExpr::Wildcard(None) => {}
268 FunctionArgExpr::QualifiedWildcard(_, Some(exprs))
269 | FunctionArgExpr::Wildcard(Some(exprs)) => {
270 for expr in exprs {
271 self.visit_expr(expr);
272 }
273 }
274 },
275 }
276 }
277
278 fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
279 for arg in &mut arg_list.args {
280 self.visit_function_arg(arg);
281 }
282 for expr in &mut arg_list.order_by {
283 self.visit_expr(&mut expr.expr)
284 }
285 }
286
287 fn visit_function(&self, function: &mut Function) {
289 self.visit_function_arg_list(&mut function.arg_list);
290 if let Some(over) = &mut function.over {
291 match over {
292 Window::Spec(window) => {
293 for expr in &mut window.partition_by {
294 self.visit_expr(expr);
295 }
296 for expr in &mut window.order_by {
297 self.visit_expr(&mut expr.expr);
298 }
299 }
300 Window::Name(_) => {
301 }
303 }
304 }
305 }
306
307 fn visit_expr(&self, expr: &mut Expr) {
309 match expr {
310 Expr::FieldIdentifier(expr, ..)
311 | Expr::IsNull(expr)
312 | Expr::IsNotNull(expr)
313 | Expr::IsTrue(expr)
314 | Expr::IsNotTrue(expr)
315 | Expr::IsFalse(expr)
316 | Expr::IsNotFalse(expr)
317 | Expr::IsUnknown(expr)
318 | Expr::IsNotUnknown(expr)
319 | Expr::IsJson { expr, .. }
320 | Expr::InList { expr, .. }
321 | Expr::SomeOp(expr)
322 | Expr::AllOp(expr)
323 | Expr::UnaryOp { expr, .. }
324 | Expr::Cast { expr, .. }
325 | Expr::TryCast { expr, .. }
326 | Expr::AtTimeZone {
327 timestamp: expr, ..
328 }
329 | Expr::Extract { expr, .. }
330 | Expr::Substring { expr, .. }
331 | Expr::Overlay { expr, .. }
332 | Expr::Trim { expr, .. }
333 | Expr::Nested(expr)
334 | Expr::Index { obj: expr, .. }
335 | Expr::ArrayRangeIndex { obj: expr, .. } => self.visit_expr(expr),
336
337 Expr::Position { substring, string } => {
338 self.visit_expr(substring);
339 self.visit_expr(string);
340 }
341
342 Expr::InSubquery { expr, subquery, .. } => {
343 self.visit_expr(expr);
344 self.visit_query(subquery);
345 }
346 Expr::Between {
347 expr, low, high, ..
348 } => {
349 self.visit_expr(expr);
350 self.visit_expr(low);
351 self.visit_expr(high);
352 }
353 Expr::Like {
354 expr, pattern: pat, ..
355 } => {
356 self.visit_expr(expr);
357 self.visit_expr(pat);
358 }
359 Expr::ILike {
360 expr, pattern: pat, ..
361 } => {
362 self.visit_expr(expr);
363 self.visit_expr(pat);
364 }
365 Expr::SimilarTo {
366 expr, pattern: pat, ..
367 } => {
368 self.visit_expr(expr);
369 self.visit_expr(pat);
370 }
371
372 Expr::IsDistinctFrom(expr1, expr2)
373 | Expr::IsNotDistinctFrom(expr1, expr2)
374 | Expr::BinaryOp {
375 left: expr1,
376 right: expr2,
377 ..
378 } => {
379 self.visit_expr(expr1);
380 self.visit_expr(expr2);
381 }
382 Expr::Function(function) => self.visit_function(function),
383 Expr::Exists(query) | Expr::Subquery(query) | Expr::ArraySubquery(query) => {
384 self.visit_query(query)
385 }
386
387 Expr::GroupingSets(exprs_vec) | Expr::Cube(exprs_vec) | Expr::Rollup(exprs_vec) => {
388 for exprs in exprs_vec {
389 for expr in exprs {
390 self.visit_expr(expr);
391 }
392 }
393 }
394
395 Expr::Row(exprs) | Expr::Array(Array { elem: exprs, .. }) => {
396 for expr in exprs {
397 self.visit_expr(expr);
398 }
399 }
400 Expr::Map { entries } => {
401 for (key, value) in entries {
402 self.visit_expr(key);
403 self.visit_expr(value);
404 }
405 }
406
407 Expr::LambdaFunction { body, args: _ } => self.visit_expr(body),
408
409 Expr::Identifier(_)
411 | Expr::CompoundIdentifier(_)
412 | Expr::Collate { .. }
413 | Expr::Value(_)
414 | Expr::Parameter { .. }
415 | Expr::TypedString { .. }
416 | Expr::Case { .. } => {}
417 }
418 }
419
420 fn visit_select_item(&self, select_item: &mut SelectItem) {
422 match select_item {
423 SelectItem::UnnamedExpr(expr)
424 | SelectItem::ExprQualifiedWildcard(expr, _)
425 | SelectItem::ExprWithAlias { expr, .. } => self.visit_expr(expr),
426 SelectItem::QualifiedWildcard(_, None) | SelectItem::Wildcard(None) => {}
427 SelectItem::QualifiedWildcard(_, Some(exprs)) | SelectItem::Wildcard(Some(exprs)) => {
428 for expr in exprs {
429 self.visit_expr(expr);
430 }
431 }
432 }
433 }
434}
435
436pub struct IndexItemRewriter {
439 pub original_columns: Vec<PbColumnDesc>,
440 pub new_columns: Vec<PbColumnDesc>,
441}
442
443impl IndexItemRewriter {
444 pub fn rewrite_expr(&self, expr: &mut ExprNode) {
445 let rex_node = expr.rex_node.as_mut().unwrap();
446 match rex_node {
447 RexNode::InputRef(idx) => {
448 let old_idx = *idx as usize;
449 let original_column = &self.original_columns[old_idx];
450 let (new_idx, new_column) = self
451 .new_columns
452 .iter()
453 .find_position(|c| c.column_id == original_column.column_id)
454 .expect("should already checked index referencing column still exists");
455 *idx = new_idx as u32;
456
457 if new_column.column_type != original_column.column_type {
461 let old_type = original_column.column_type.clone().unwrap();
462 let new_type = new_column.column_type.clone().unwrap();
463
464 assert_eq!(&old_type, expr.return_type.as_ref().unwrap());
465 expr.return_type = Some(new_type); let new_expr_node = ExprNode {
468 function_type: expr_node::Type::CompositeCast as _,
469 return_type: Some(old_type),
470 rex_node: RexNode::FuncCall(FunctionCall {
471 children: vec![expr.clone()],
472 })
473 .into(),
474 };
475
476 *expr = new_expr_node;
477 }
478 }
479 RexNode::Constant(_) => {}
480 RexNode::Udf(udf) => self.rewrite_udf(udf),
481 RexNode::FuncCall(function_call) => self.rewrite_function_call(function_call),
482 RexNode::Now(_) => {}
483 }
484 }
485
486 fn rewrite_udf(&self, udf: &mut UserDefinedFunction) {
487 udf.children
488 .iter_mut()
489 .for_each(|expr| self.rewrite_expr(expr));
490 }
491
492 fn rewrite_function_call(&self, function_call: &mut FunctionCall) {
493 function_call
494 .children
495 .iter_mut()
496 .for_each(|expr| self.rewrite_expr(expr));
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that rewriting references from `from` to `to` in `definition` yields `expected`.
    fn assert_rename_refs(definition: &str, from: &str, to: &str, expected: &str) {
        let actual = alter_relation_rename_refs(definition, from, to);
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_alter_table_rename() {
        let actual = alter_relation_rename("CREATE TABLE foo (a int, b int)", "bar");
        assert_eq!("CREATE TABLE bar (a INT, b INT)", actual);
    }

    #[test]
    fn test_rename_index_refs() {
        assert_rename_refs(
            "CREATE INDEX idx1 ON foo(v1 DESC, v2)",
            "foo",
            "bar",
            "CREATE INDEX idx1 ON bar(v1 DESC, v2)",
        );
    }

    #[test]
    fn test_rename_sink_refs() {
        assert_rename_refs(
            "CREATE SINK sink_t FROM foo WITH (connector = 'kafka', format = 'append_only')",
            "foo",
            "bar",
            "CREATE SINK sink_t FROM bar WITH (connector = 'kafka', format = 'append_only')",
        );
    }

    #[test]
    fn test_rename_with_alias_refs() {
        let cases = [
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM foo",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM bar AS foo",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM bar AS foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM foo AS bar WHERE bar.v1 = 1",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM bar AS bar WHERE bar.v1 = 1",
            ),
        ];
        for (definition, expected) in cases {
            assert_rename_refs(definition, "foo", "bar", expected);
        }
    }

    #[test]
    fn test_rename_with_complex_funcs() {
        let definition = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
            agg1(\
                foo.v1, func2(foo.v2) \
                ORDER BY \
                (SELECT foo.v3 FROM foo), \
                (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM foo) ORDER BY (SELECT foo.v6 FROM foo)) FROM foo)\
            ) \
        FROM foo";
        let expected = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
            agg1(\
                foo.v1, func2(foo.v2) \
                ORDER BY \
                (SELECT foo.v3 FROM bar AS foo), \
                (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM bar AS foo) ORDER BY (SELECT foo.v6 FROM bar AS foo)) FROM bar AS foo)\
            ) \
        FROM bar AS foo";
        assert_rename_refs(definition, "foo", "bar", expected);
    }
}