1use itertools::Itertools;
16use risingwave_pb::expr::expr_node::{self, RexNode};
17use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
18use risingwave_pb::plan_common::PbColumnDesc;
19use risingwave_sqlparser::ast::{
20 Array, CdcTableInfo, CreateSink, CreateSinkStatement, CreateSourceStatement,
21 CreateSubscriptionStatement, Distinct, Expr, Function, FunctionArg, FunctionArgExpr,
22 FunctionArgList, Ident, ObjectName, Query, SelectItem, SetExpr, Statement, TableAlias,
23 TableFactor, TableWithJoins, Window,
24};
25use risingwave_sqlparser::parser::Parser;
26
27pub fn alter_relation_rename(definition: &str, new_name: &str) -> String {
32 if definition.is_empty() {
35 tracing::warn!("found empty definition when renaming relation, ignored.");
36 return definition.into();
37 }
38 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
39 let mut stmt =
40 Itertools::exactly_one(ast.into_iter()).expect("should contains only one statement");
41
42 match &mut stmt {
43 Statement::CreateTable { name, .. }
44 | Statement::CreateView { name, .. }
45 | Statement::CreateIndex { name, .. }
46 | Statement::CreateSource {
47 stmt: CreateSourceStatement {
48 source_name: name, ..
49 },
50 }
51 | Statement::CreateSubscription {
52 stmt:
53 CreateSubscriptionStatement {
54 subscription_name: name,
55 ..
56 },
57 }
58 | Statement::CreateSink {
59 stmt: CreateSinkStatement {
60 sink_name: name, ..
61 },
62 } => replace_table_name(name, new_name),
63 _ => unreachable!(),
64 };
65
66 stmt.to_string()
67}
68
69pub fn alter_relation_rename_refs(definition: &str, from: &str, to: &str) -> String {
72 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
73 let mut stmt =
74 Itertools::exactly_one(ast.into_iter()).expect("should contains only one statement");
75
76 match &mut stmt {
77 Statement::CreateTable {
78 query: Some(query), ..
79 }
80 | Statement::CreateView { query, .. }
81 | Statement::Query(query) | Statement::CreateSink {
83 stmt:
84 CreateSinkStatement {
85 sink_from: CreateSink::AsQuery(query),
86 into_table_name: None,
87 ..
88 },
89 } => {
90 QueryRewriter::rewrite_query(query, from, to);
91 }
92 Statement::CreateIndex { table_name, .. }
93 | Statement::CreateSink {
94 stmt:
95 CreateSinkStatement {
96 sink_from: CreateSink::From(table_name),
97 into_table_name: None,
98 ..
99 },
100 }| Statement::CreateSubscription {
101 stmt:
102 CreateSubscriptionStatement {
103 subscription_from: table_name,
104 ..
105 },
106 } | Statement::CreateTable {
107 cdc_table_info:
108 Some(CdcTableInfo {
109 source_name: table_name,
110 ..
111 }),
112 ..
113 } => replace_table_name(table_name, to),
114 Statement::CreateSink {
115 stmt: CreateSinkStatement {
116 sink_from,
117 into_table_name: Some(table_name),
118 ..
119 }
120 } => {
121 let idx = table_name.0.len() - 1;
122 if table_name.0[idx].real_value() == from {
123 table_name.0[idx] = Ident::from_real_value(to);
124 } else {
125 match sink_from {
126 CreateSink::From(table_name) => replace_table_name(table_name, to),
127 CreateSink::AsQuery(query) => QueryRewriter::rewrite_query(query, from, to),
128 }
129 }
130 }
131 _ => unreachable!(),
132 };
133 stmt.to_string()
134}
135
136fn replace_table_name(table_name: &mut ObjectName, to: &str) {
139 let idx = table_name.0.len() - 1;
140 table_name.0[idx] = Ident::from_real_value(to);
141}
142
/// Walks a parsed query and rewrites table references to the relation `from` so
/// that they name `to` instead.
struct QueryRewriter<'a> {
    /// Name of the relation being renamed; table references are compared
    /// against it.
    from: &'a str,
    /// Name that matching table references are rewritten to.
    to: &'a str,
}
149
150impl QueryRewriter<'_> {
151 fn rewrite_query(query: &mut Query, from: &str, to: &str) {
152 let rewriter = QueryRewriter { from, to };
153 rewriter.visit_query(query)
154 }
155
156 fn visit_query(&self, query: &mut Query) {
158 if let Some(with) = &mut query.with {
159 for cte_table in &mut with.cte_tables {
160 match &mut cte_table.cte_inner {
161 risingwave_sqlparser::ast::CteInner::Query(query) => self.visit_query(query),
162 risingwave_sqlparser::ast::CteInner::ChangeLog(name) => {
163 let idx = name.0.len() - 1;
164 if name.0[idx].real_value() == self.from {
165 replace_table_name(name, self.to);
166 }
167 }
168 }
169 }
170 }
171 self.visit_set_expr(&mut query.body);
172 for expr in &mut query.order_by {
173 self.visit_expr(&mut expr.expr);
174 }
175 }
176
177 fn visit_table_factor(&self, table_factor: &mut TableFactor) {
189 match table_factor {
190 TableFactor::Table { name, alias, .. } => {
191 let idx = name.0.len() - 1;
192 if name.0[idx].real_value() == self.from {
193 if alias.is_none() {
194 *alias = Some(TableAlias {
195 name: Ident::from_real_value(self.from),
196 columns: vec![],
197 });
198 }
199 name.0[idx] = Ident::from_real_value(self.to);
200 }
201 }
202 TableFactor::Derived { subquery, .. } => self.visit_query(subquery),
203 TableFactor::TableFunction { args, .. } => {
204 for arg in args {
205 self.visit_function_arg(arg);
206 }
207 }
208 TableFactor::NestedJoin(table_with_joins) => {
209 self.visit_table_with_joins(table_with_joins);
210 }
211 }
212 }
213
214 fn visit_table_with_joins(&self, table_with_joins: &mut TableWithJoins) {
216 self.visit_table_factor(&mut table_with_joins.relation);
217 for join in &mut table_with_joins.joins {
218 self.visit_table_factor(&mut join.relation);
219 }
220 }
221
222 fn visit_set_expr(&self, set_expr: &mut SetExpr) {
224 match set_expr {
225 SetExpr::Select(select) => {
226 if let Distinct::DistinctOn(exprs) = &mut select.distinct {
227 for expr in exprs {
228 self.visit_expr(expr);
229 }
230 }
231 for select_item in &mut select.projection {
232 self.visit_select_item(select_item);
233 }
234 for from_item in &mut select.from {
235 self.visit_table_with_joins(from_item);
236 }
237 if let Some(where_clause) = &mut select.selection {
238 self.visit_expr(where_clause);
239 }
240 for expr in &mut select.group_by {
241 self.visit_expr(expr);
242 }
243 if let Some(having) = &mut select.having {
244 self.visit_expr(having);
245 }
246 for named_window in &mut select.window {
247 for expr in &mut named_window.window_spec.partition_by {
248 self.visit_expr(expr);
249 }
250 for expr in &mut named_window.window_spec.order_by {
251 self.visit_expr(&mut expr.expr);
252 }
253 }
254 }
255 SetExpr::Query(query) => self.visit_query(query),
256 SetExpr::SetOperation { left, right, .. } => {
257 self.visit_set_expr(left);
258 self.visit_set_expr(right);
259 }
260 SetExpr::Values(_) => {}
261 }
262 }
263
264 fn visit_function_arg(&self, function_arg: &mut FunctionArg) {
266 match function_arg {
267 FunctionArg::Unnamed(arg) | FunctionArg::Named { arg, .. } => match arg {
268 FunctionArgExpr::Expr(expr) | FunctionArgExpr::ExprQualifiedWildcard(expr, _) => {
269 self.visit_expr(expr)
270 }
271 FunctionArgExpr::QualifiedWildcard(_, None) | FunctionArgExpr::Wildcard(None) => {}
272 FunctionArgExpr::QualifiedWildcard(_, Some(exprs))
273 | FunctionArgExpr::Wildcard(Some(exprs)) => {
274 for expr in exprs {
275 self.visit_expr(expr);
276 }
277 }
278 FunctionArgExpr::SecretRef(_) => {}
279 },
280 }
281 }
282
283 fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
284 for arg in &mut arg_list.args {
285 self.visit_function_arg(arg);
286 }
287 for expr in &mut arg_list.order_by {
288 self.visit_expr(&mut expr.expr)
289 }
290 }
291
292 fn visit_function(&self, function: &mut Function) {
294 self.visit_function_arg_list(&mut function.arg_list);
295 if let Some(over) = &mut function.over {
296 match over {
297 Window::Spec(window) => {
298 for expr in &mut window.partition_by {
299 self.visit_expr(expr);
300 }
301 for expr in &mut window.order_by {
302 self.visit_expr(&mut expr.expr);
303 }
304 }
305 Window::Name(_) => {
306 }
308 }
309 }
310 }
311
312 fn visit_expr(&self, expr: &mut Expr) {
314 match expr {
315 Expr::FieldIdentifier(expr, ..)
316 | Expr::IsNull(expr)
317 | Expr::IsNotNull(expr)
318 | Expr::IsTrue(expr)
319 | Expr::IsNotTrue(expr)
320 | Expr::IsFalse(expr)
321 | Expr::IsNotFalse(expr)
322 | Expr::IsUnknown(expr)
323 | Expr::IsNotUnknown(expr)
324 | Expr::IsJson { expr, .. }
325 | Expr::InList { expr, .. }
326 | Expr::SomeOp(expr)
327 | Expr::AllOp(expr)
328 | Expr::UnaryOp { expr, .. }
329 | Expr::Cast { expr, .. }
330 | Expr::TryCast { expr, .. }
331 | Expr::AtTimeZone {
332 timestamp: expr, ..
333 }
334 | Expr::Extract { expr, .. }
335 | Expr::Substring { expr, .. }
336 | Expr::Overlay { expr, .. }
337 | Expr::Trim { expr, .. }
338 | Expr::Nested(expr)
339 | Expr::Index { obj: expr, .. }
340 | Expr::ArrayRangeIndex { obj: expr, .. } => self.visit_expr(expr),
341
342 Expr::Position { substring, string } => {
343 self.visit_expr(substring);
344 self.visit_expr(string);
345 }
346
347 Expr::InSubquery { expr, subquery, .. } => {
348 self.visit_expr(expr);
349 self.visit_query(subquery);
350 }
351 Expr::Between {
352 expr, low, high, ..
353 } => {
354 self.visit_expr(expr);
355 self.visit_expr(low);
356 self.visit_expr(high);
357 }
358 Expr::Like {
359 expr, pattern: pat, ..
360 } => {
361 self.visit_expr(expr);
362 self.visit_expr(pat);
363 }
364 Expr::ILike {
365 expr, pattern: pat, ..
366 } => {
367 self.visit_expr(expr);
368 self.visit_expr(pat);
369 }
370 Expr::SimilarTo {
371 expr, pattern: pat, ..
372 } => {
373 self.visit_expr(expr);
374 self.visit_expr(pat);
375 }
376
377 Expr::IsDistinctFrom(expr1, expr2)
378 | Expr::IsNotDistinctFrom(expr1, expr2)
379 | Expr::BinaryOp {
380 left: expr1,
381 right: expr2,
382 ..
383 } => {
384 self.visit_expr(expr1);
385 self.visit_expr(expr2);
386 }
387 Expr::Function(function) => self.visit_function(function),
388 Expr::Exists(query) | Expr::Subquery(query) | Expr::ArraySubquery(query) => {
389 self.visit_query(query)
390 }
391
392 Expr::GroupingSets(exprs_vec) | Expr::Cube(exprs_vec) | Expr::Rollup(exprs_vec) => {
393 for exprs in exprs_vec {
394 for expr in exprs {
395 self.visit_expr(expr);
396 }
397 }
398 }
399
400 Expr::Row(exprs) | Expr::Array(Array { elem: exprs, .. }) => {
401 for expr in exprs {
402 self.visit_expr(expr);
403 }
404 }
405 Expr::Map { entries } => {
406 for (key, value) in entries {
407 self.visit_expr(key);
408 self.visit_expr(value);
409 }
410 }
411
412 Expr::LambdaFunction { body, args: _ } => self.visit_expr(body),
413
414 Expr::Identifier(_)
416 | Expr::CompoundIdentifier(_)
417 | Expr::Collate { .. }
418 | Expr::Value(_)
419 | Expr::Parameter { .. }
420 | Expr::TypedString { .. }
421 | Expr::Case { .. } => {}
422 }
423 }
424
425 fn visit_select_item(&self, select_item: &mut SelectItem) {
427 match select_item {
428 SelectItem::UnnamedExpr(expr)
429 | SelectItem::ExprQualifiedWildcard(expr, _)
430 | SelectItem::ExprWithAlias { expr, .. } => self.visit_expr(expr),
431 SelectItem::QualifiedWildcard(_, None) | SelectItem::Wildcard(None) => {}
432 SelectItem::QualifiedWildcard(_, Some(exprs)) | SelectItem::Wildcard(Some(exprs)) => {
433 for expr in exprs {
434 self.visit_expr(expr);
435 }
436 }
437 }
438 }
439}
440
/// Rewrites expression nodes so that their input references follow a change of
/// column list from `original_columns` to `new_columns`.
pub struct IndexItemRewriter {
    /// Column list that the existing `InputRef` indices point into.
    pub original_columns: Vec<PbColumnDesc>,
    /// Column list that rewritten `InputRef` indices point into; columns are
    /// matched to the original ones by `column_id`.
    pub new_columns: Vec<PbColumnDesc>,
}
447
448impl IndexItemRewriter {
449 pub fn rewrite_expr(&self, expr: &mut ExprNode) {
450 let rex_node = expr.rex_node.as_mut().unwrap();
451 match rex_node {
452 RexNode::InputRef(idx) => {
453 let old_idx = *idx as usize;
454 let original_column = &self.original_columns[old_idx];
455 let (new_idx, new_column) = self
456 .new_columns
457 .iter()
458 .find_position(|c| c.column_id == original_column.column_id)
459 .expect("should already checked index referencing column still exists");
460 *idx = new_idx as u32;
461
462 if new_column.column_type != original_column.column_type {
466 let old_type = original_column.column_type.clone().unwrap();
467 let new_type = new_column.column_type.clone().unwrap();
468
469 assert_eq!(&old_type, expr.return_type.as_ref().unwrap());
470 expr.return_type = Some(new_type); let new_expr_node = ExprNode {
473 function_type: expr_node::Type::CompositeCast as _,
474 return_type: Some(old_type),
475 rex_node: RexNode::FuncCall(FunctionCall {
476 children: vec![expr.clone()],
477 })
478 .into(),
479 };
480
481 *expr = new_expr_node;
482 }
483 }
484 RexNode::Constant(_) => {}
485 RexNode::Udf(udf) => self.rewrite_udf(udf),
486 RexNode::FuncCall(function_call) => self.rewrite_function_call(function_call),
487 RexNode::Now(_) | RexNode::SecretRef(_) => {}
488 }
489 }
490
491 fn rewrite_udf(&self, udf: &mut UserDefinedFunction) {
492 udf.children
493 .iter_mut()
494 .for_each(|expr| self.rewrite_expr(expr));
495 }
496
497 fn rewrite_function_call(&self, function_call: &mut FunctionCall) {
498 function_call
499 .children
500 .iter_mut()
501 .for_each(|expr| self.rewrite_expr(expr));
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewrites references `from` -> `to` in `definition` and checks the
    /// regenerated SQL against `expected`.
    #[track_caller]
    fn assert_refs_renamed(definition: &str, from: &str, to: &str, expected: &str) {
        assert_eq!(expected, alter_relation_rename_refs(definition, from, to));
    }

    /// Renaming a table rewrites the name in its own `CREATE TABLE` statement.
    #[test]
    fn test_alter_table_rename() {
        let renamed = alter_relation_rename("CREATE TABLE foo (a int, b int)", "bar");
        assert_eq!("CREATE TABLE bar (a INT, b INT)", renamed);
    }

    /// An index definition follows the rename of the table it is built on.
    #[test]
    fn test_rename_index_refs() {
        assert_refs_renamed(
            "CREATE INDEX idx1 ON foo(v1 DESC, v2)",
            "foo",
            "bar",
            "CREATE INDEX idx1 ON bar(v1 DESC, v2)",
        );
    }

    /// A sink defined `FROM` a relation follows the rename of that relation.
    #[test]
    fn test_rename_sink_refs() {
        assert_refs_renamed(
            "CREATE SINK sink_t FROM foo WITH (connector = 'kafka', format = 'append_only')",
            "foo",
            "bar",
            "CREATE SINK sink_t FROM bar WITH (connector = 'kafka', format = 'append_only')",
        );
    }

    /// The old name is kept as an alias (or an existing alias is preserved) so
    /// that column references do not need rewriting.
    #[test]
    fn test_rename_with_alias_refs() {
        let cases = [
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM foo",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM bar AS foo",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM bar AS foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM foo AS bar WHERE bar.v1 = 1",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM bar AS bar WHERE bar.v1 = 1",
            ),
        ];

        for (definition, expected) in cases {
            assert_refs_renamed(definition, "foo", "bar", expected);
        }
    }

    /// Subqueries nested in aggregate arguments, `ORDER BY` and window
    /// specifications are all rewritten.
    #[test]
    fn test_rename_with_complex_funcs() {
        let definition = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                            agg1(\
                                foo.v1, func2(foo.v2) \
                                ORDER BY \
                                (SELECT foo.v3 FROM foo), \
                                (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM foo) ORDER BY (SELECT foo.v6 FROM foo)) FROM foo)\
                            ) \
                            FROM foo";
        let expected = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                            agg1(\
                                foo.v1, func2(foo.v2) \
                                ORDER BY \
                                (SELECT foo.v3 FROM bar AS foo), \
                                (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM bar AS foo) ORDER BY (SELECT foo.v6 FROM bar AS foo)) FROM bar AS foo)\
                            ) \
                            FROM bar AS foo";
        assert_refs_renamed(definition, "foo", "bar", expected);
    }
}