1use itertools::Itertools;
16use risingwave_pb::expr::expr_node::{self, RexNode};
17use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
18use risingwave_pb::plan_common::PbColumnDesc;
19use risingwave_sqlparser::ast::{
20 Array, CdcTableInfo, CreateSink, CreateSinkStatement, CreateSourceStatement,
21 CreateSubscriptionStatement, Distinct, Expr, Function, FunctionArg, FunctionArgExpr,
22 FunctionArgList, Ident, ObjectName, Query, SelectItem, SetExpr, Statement, TableAlias,
23 TableFactor, TableWithJoins, Window,
24};
25use risingwave_sqlparser::parser::Parser;
26
27pub fn alter_relation_rename(definition: &str, new_name: &str) -> String {
32 if definition.is_empty() {
35 tracing::warn!("found empty definition when renaming relation, ignored.");
36 return definition.into();
37 }
38 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
39 let mut stmt = ast
40 .into_iter()
41 .exactly_one()
42 .expect("should contains only one statement");
43
44 match &mut stmt {
45 Statement::CreateTable { name, .. }
46 | Statement::CreateView { name, .. }
47 | Statement::CreateIndex { name, .. }
48 | Statement::CreateSource {
49 stmt: CreateSourceStatement {
50 source_name: name, ..
51 },
52 }
53 | Statement::CreateSubscription {
54 stmt:
55 CreateSubscriptionStatement {
56 subscription_name: name,
57 ..
58 },
59 }
60 | Statement::CreateSink {
61 stmt: CreateSinkStatement {
62 sink_name: name, ..
63 },
64 } => replace_table_name(name, new_name),
65 _ => unreachable!(),
66 };
67
68 stmt.to_string()
69}
70
71pub fn alter_relation_rename_refs(definition: &str, from: &str, to: &str) -> String {
74 let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
75 let mut stmt = ast
76 .into_iter()
77 .exactly_one()
78 .expect("should contains only one statement");
79
80 match &mut stmt {
81 Statement::CreateTable {
82 query: Some(query), ..
83 }
84 | Statement::CreateView { query, .. }
85 | Statement::Query(query) | Statement::CreateSink {
87 stmt:
88 CreateSinkStatement {
89 sink_from: CreateSink::AsQuery(query),
90 into_table_name: None,
91 ..
92 },
93 } => {
94 QueryRewriter::rewrite_query(query, from, to);
95 }
96 Statement::CreateIndex { table_name, .. }
97 | Statement::CreateSink {
98 stmt:
99 CreateSinkStatement {
100 sink_from: CreateSink::From(table_name),
101 into_table_name: None,
102 ..
103 },
104 }| Statement::CreateSubscription {
105 stmt:
106 CreateSubscriptionStatement {
107 subscription_from: table_name,
108 ..
109 },
110 } | Statement::CreateTable {
111 cdc_table_info:
112 Some(CdcTableInfo {
113 source_name: table_name,
114 ..
115 }),
116 ..
117 } => replace_table_name(table_name, to),
118 Statement::CreateSink {
119 stmt: CreateSinkStatement {
120 sink_from,
121 into_table_name: Some(table_name),
122 ..
123 }
124 } => {
125 let idx = table_name.0.len() - 1;
126 if table_name.0[idx].real_value() == from {
127 table_name.0[idx] = Ident::from_real_value(to);
128 } else {
129 match sink_from {
130 CreateSink::From(table_name) => replace_table_name(table_name, to),
131 CreateSink::AsQuery(query) => QueryRewriter::rewrite_query(query, from, to),
132 }
133 }
134 }
135 _ => unreachable!(),
136 };
137 stmt.to_string()
138}
139
140fn replace_table_name(table_name: &mut ObjectName, to: &str) {
143 let idx = table_name.0.len() - 1;
144 table_name.0[idx] = Ident::from_real_value(to);
145}
146
147struct QueryRewriter<'a> {
150 from: &'a str,
151 to: &'a str,
152}
153
154impl QueryRewriter<'_> {
155 fn rewrite_query(query: &mut Query, from: &str, to: &str) {
156 let rewriter = QueryRewriter { from, to };
157 rewriter.visit_query(query)
158 }
159
160 fn visit_query(&self, query: &mut Query) {
162 if let Some(with) = &mut query.with {
163 for cte_table in &mut with.cte_tables {
164 match &mut cte_table.cte_inner {
165 risingwave_sqlparser::ast::CteInner::Query(query) => self.visit_query(query),
166 risingwave_sqlparser::ast::CteInner::ChangeLog(name) => {
167 let idx = name.0.len() - 1;
168 if name.0[idx].real_value() == self.from {
169 replace_table_name(name, self.to);
170 }
171 }
172 }
173 }
174 }
175 self.visit_set_expr(&mut query.body);
176 for expr in &mut query.order_by {
177 self.visit_expr(&mut expr.expr);
178 }
179 }
180
181 fn visit_table_factor(&self, table_factor: &mut TableFactor) {
193 match table_factor {
194 TableFactor::Table { name, alias, .. } => {
195 let idx = name.0.len() - 1;
196 if name.0[idx].real_value() == self.from {
197 if alias.is_none() {
198 *alias = Some(TableAlias {
199 name: Ident::from_real_value(self.from),
200 columns: vec![],
201 });
202 }
203 name.0[idx] = Ident::from_real_value(self.to);
204 }
205 }
206 TableFactor::Derived { subquery, .. } => self.visit_query(subquery),
207 TableFactor::TableFunction { args, .. } => {
208 for arg in args {
209 self.visit_function_arg(arg);
210 }
211 }
212 TableFactor::NestedJoin(table_with_joins) => {
213 self.visit_table_with_joins(table_with_joins);
214 }
215 }
216 }
217
218 fn visit_table_with_joins(&self, table_with_joins: &mut TableWithJoins) {
220 self.visit_table_factor(&mut table_with_joins.relation);
221 for join in &mut table_with_joins.joins {
222 self.visit_table_factor(&mut join.relation);
223 }
224 }
225
226 fn visit_set_expr(&self, set_expr: &mut SetExpr) {
228 match set_expr {
229 SetExpr::Select(select) => {
230 if let Distinct::DistinctOn(exprs) = &mut select.distinct {
231 for expr in exprs {
232 self.visit_expr(expr);
233 }
234 }
235 for select_item in &mut select.projection {
236 self.visit_select_item(select_item);
237 }
238 for from_item in &mut select.from {
239 self.visit_table_with_joins(from_item);
240 }
241 if let Some(where_clause) = &mut select.selection {
242 self.visit_expr(where_clause);
243 }
244 for expr in &mut select.group_by {
245 self.visit_expr(expr);
246 }
247 if let Some(having) = &mut select.having {
248 self.visit_expr(having);
249 }
250 for named_window in &mut select.window {
251 for expr in &mut named_window.window_spec.partition_by {
252 self.visit_expr(expr);
253 }
254 for expr in &mut named_window.window_spec.order_by {
255 self.visit_expr(&mut expr.expr);
256 }
257 }
258 }
259 SetExpr::Query(query) => self.visit_query(query),
260 SetExpr::SetOperation { left, right, .. } => {
261 self.visit_set_expr(left);
262 self.visit_set_expr(right);
263 }
264 SetExpr::Values(_) => {}
265 }
266 }
267
268 fn visit_function_arg(&self, function_arg: &mut FunctionArg) {
270 match function_arg {
271 FunctionArg::Unnamed(arg) | FunctionArg::Named { arg, .. } => match arg {
272 FunctionArgExpr::Expr(expr) | FunctionArgExpr::ExprQualifiedWildcard(expr, _) => {
273 self.visit_expr(expr)
274 }
275 FunctionArgExpr::QualifiedWildcard(_, None) | FunctionArgExpr::Wildcard(None) => {}
276 FunctionArgExpr::QualifiedWildcard(_, Some(exprs))
277 | FunctionArgExpr::Wildcard(Some(exprs)) => {
278 for expr in exprs {
279 self.visit_expr(expr);
280 }
281 }
282 },
283 }
284 }
285
286 fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
287 for arg in &mut arg_list.args {
288 self.visit_function_arg(arg);
289 }
290 for expr in &mut arg_list.order_by {
291 self.visit_expr(&mut expr.expr)
292 }
293 }
294
295 fn visit_function(&self, function: &mut Function) {
297 self.visit_function_arg_list(&mut function.arg_list);
298 if let Some(over) = &mut function.over {
299 match over {
300 Window::Spec(window) => {
301 for expr in &mut window.partition_by {
302 self.visit_expr(expr);
303 }
304 for expr in &mut window.order_by {
305 self.visit_expr(&mut expr.expr);
306 }
307 }
308 Window::Name(_) => {
309 }
311 }
312 }
313 }
314
315 fn visit_expr(&self, expr: &mut Expr) {
317 match expr {
318 Expr::FieldIdentifier(expr, ..)
319 | Expr::IsNull(expr)
320 | Expr::IsNotNull(expr)
321 | Expr::IsTrue(expr)
322 | Expr::IsNotTrue(expr)
323 | Expr::IsFalse(expr)
324 | Expr::IsNotFalse(expr)
325 | Expr::IsUnknown(expr)
326 | Expr::IsNotUnknown(expr)
327 | Expr::IsJson { expr, .. }
328 | Expr::InList { expr, .. }
329 | Expr::SomeOp(expr)
330 | Expr::AllOp(expr)
331 | Expr::UnaryOp { expr, .. }
332 | Expr::Cast { expr, .. }
333 | Expr::TryCast { expr, .. }
334 | Expr::AtTimeZone {
335 timestamp: expr, ..
336 }
337 | Expr::Extract { expr, .. }
338 | Expr::Substring { expr, .. }
339 | Expr::Overlay { expr, .. }
340 | Expr::Trim { expr, .. }
341 | Expr::Nested(expr)
342 | Expr::Index { obj: expr, .. }
343 | Expr::ArrayRangeIndex { obj: expr, .. } => self.visit_expr(expr),
344
345 Expr::Position { substring, string } => {
346 self.visit_expr(substring);
347 self.visit_expr(string);
348 }
349
350 Expr::InSubquery { expr, subquery, .. } => {
351 self.visit_expr(expr);
352 self.visit_query(subquery);
353 }
354 Expr::Between {
355 expr, low, high, ..
356 } => {
357 self.visit_expr(expr);
358 self.visit_expr(low);
359 self.visit_expr(high);
360 }
361 Expr::Like {
362 expr, pattern: pat, ..
363 } => {
364 self.visit_expr(expr);
365 self.visit_expr(pat);
366 }
367 Expr::ILike {
368 expr, pattern: pat, ..
369 } => {
370 self.visit_expr(expr);
371 self.visit_expr(pat);
372 }
373 Expr::SimilarTo {
374 expr, pattern: pat, ..
375 } => {
376 self.visit_expr(expr);
377 self.visit_expr(pat);
378 }
379
380 Expr::IsDistinctFrom(expr1, expr2)
381 | Expr::IsNotDistinctFrom(expr1, expr2)
382 | Expr::BinaryOp {
383 left: expr1,
384 right: expr2,
385 ..
386 } => {
387 self.visit_expr(expr1);
388 self.visit_expr(expr2);
389 }
390 Expr::Function(function) => self.visit_function(function),
391 Expr::Exists(query) | Expr::Subquery(query) | Expr::ArraySubquery(query) => {
392 self.visit_query(query)
393 }
394
395 Expr::GroupingSets(exprs_vec) | Expr::Cube(exprs_vec) | Expr::Rollup(exprs_vec) => {
396 for exprs in exprs_vec {
397 for expr in exprs {
398 self.visit_expr(expr);
399 }
400 }
401 }
402
403 Expr::Row(exprs) | Expr::Array(Array { elem: exprs, .. }) => {
404 for expr in exprs {
405 self.visit_expr(expr);
406 }
407 }
408 Expr::Map { entries } => {
409 for (key, value) in entries {
410 self.visit_expr(key);
411 self.visit_expr(value);
412 }
413 }
414
415 Expr::LambdaFunction { body, args: _ } => self.visit_expr(body),
416
417 Expr::Identifier(_)
419 | Expr::CompoundIdentifier(_)
420 | Expr::Collate { .. }
421 | Expr::Value(_)
422 | Expr::Parameter { .. }
423 | Expr::TypedString { .. }
424 | Expr::Case { .. } => {}
425 }
426 }
427
428 fn visit_select_item(&self, select_item: &mut SelectItem) {
430 match select_item {
431 SelectItem::UnnamedExpr(expr)
432 | SelectItem::ExprQualifiedWildcard(expr, _)
433 | SelectItem::ExprWithAlias { expr, .. } => self.visit_expr(expr),
434 SelectItem::QualifiedWildcard(_, None) | SelectItem::Wildcard(None) => {}
435 SelectItem::QualifiedWildcard(_, Some(exprs)) | SelectItem::Wildcard(Some(exprs)) => {
436 for expr in exprs {
437 self.visit_expr(expr);
438 }
439 }
440 }
441 }
442}
443
444pub struct IndexItemRewriter {
447 pub original_columns: Vec<PbColumnDesc>,
448 pub new_columns: Vec<PbColumnDesc>,
449}
450
451impl IndexItemRewriter {
452 pub fn rewrite_expr(&self, expr: &mut ExprNode) {
453 let rex_node = expr.rex_node.as_mut().unwrap();
454 match rex_node {
455 RexNode::InputRef(idx) => {
456 let old_idx = *idx as usize;
457 let original_column = &self.original_columns[old_idx];
458 let (new_idx, new_column) = self
459 .new_columns
460 .iter()
461 .find_position(|c| c.column_id == original_column.column_id)
462 .expect("should already checked index referencing column still exists");
463 *idx = new_idx as u32;
464
465 if new_column.column_type != original_column.column_type {
469 let old_type = original_column.column_type.clone().unwrap();
470 let new_type = new_column.column_type.clone().unwrap();
471
472 assert_eq!(&old_type, expr.return_type.as_ref().unwrap());
473 expr.return_type = Some(new_type); let new_expr_node = ExprNode {
476 function_type: expr_node::Type::CompositeCast as _,
477 return_type: Some(old_type),
478 rex_node: RexNode::FuncCall(FunctionCall {
479 children: vec![expr.clone()],
480 })
481 .into(),
482 };
483
484 *expr = new_expr_node;
485 }
486 }
487 RexNode::Constant(_) => {}
488 RexNode::Udf(udf) => self.rewrite_udf(udf),
489 RexNode::FuncCall(function_call) => self.rewrite_function_call(function_call),
490 RexNode::Now(_) => {}
491 }
492 }
493
494 fn rewrite_udf(&self, udf: &mut UserDefinedFunction) {
495 udf.children
496 .iter_mut()
497 .for_each(|expr| self.rewrite_expr(expr));
498 }
499
500 fn rewrite_function_call(&self, function_call: &mut FunctionCall) {
501 function_call
502 .children
503 .iter_mut()
504 .for_each(|expr| self.rewrite_expr(expr));
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewrites references of `foo` to `bar` in `definition` and compares against `expected`.
    fn check_foo_renamed_to_bar(definition: &str, expected: &str) {
        let actual = alter_relation_rename_refs(definition, "foo", "bar");
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_alter_table_rename() {
        let renamed = alter_relation_rename("CREATE TABLE foo (a int, b int)", "bar");
        assert_eq!("CREATE TABLE bar (a INT, b INT)", renamed);
    }

    #[test]
    fn test_rename_index_refs() {
        check_foo_renamed_to_bar(
            "CREATE INDEX idx1 ON foo(v1 DESC, v2)",
            "CREATE INDEX idx1 ON bar(v1 DESC, v2)",
        );
    }

    #[test]
    fn test_rename_sink_refs() {
        check_foo_renamed_to_bar(
            "CREATE SINK sink_t FROM foo WITH (connector = 'kafka', format = 'append_only')",
            "CREATE SINK sink_t FROM bar WITH (connector = 'kafka', format = 'append_only')",
        );
    }

    #[test]
    fn test_rename_with_alias_refs() {
        // An unaliased renamed table gets the old name as alias; an explicit alias is kept.
        let cases = [
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM foo",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM bar AS foo",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM bar AS foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
            ),
            (
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM foo AS bar WHERE bar.v1 = 1",
                "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM bar AS bar WHERE bar.v1 = 1",
            ),
        ];
        for (definition, expected) in cases {
            check_foo_renamed_to_bar(definition, expected);
        }
    }

    #[test]
    fn test_rename_with_complex_funcs() {
        let definition = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                          agg1(\
                            foo.v1, func2(foo.v2) \
                            ORDER BY \
                            (SELECT foo.v3 FROM foo), \
                            (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM foo) ORDER BY (SELECT foo.v6 FROM foo)) FROM foo)\
                          ) \
                          FROM foo";
        let expected = "CREATE MATERIALIZED VIEW mv1 AS SELECT \
                        agg1(\
                          foo.v1, func2(foo.v2) \
                          ORDER BY \
                          (SELECT foo.v3 FROM bar AS foo), \
                          (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM bar AS foo) ORDER BY (SELECT foo.v6 FROM bar AS foo)) FROM bar AS foo)\
                        ) \
                        FROM bar AS foo";
        check_foo_renamed_to_bar(definition, expected);
    }
}