risingwave_meta/controller/
rename.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::expr::expr_node::{self, RexNode};
17use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction};
18use risingwave_pb::plan_common::PbColumnDesc;
19use risingwave_sqlparser::ast::{
20    Array, CdcTableInfo, CreateSink, CreateSinkStatement, CreateSourceStatement,
21    CreateSubscriptionStatement, Distinct, Expr, Function, FunctionArg, FunctionArgExpr,
22    FunctionArgList, Ident, ObjectName, Query, SelectItem, SetExpr, Statement, TableAlias,
23    TableFactor, TableWithJoins, Window,
24};
25use risingwave_sqlparser::parser::Parser;
26
27/// `alter_relation_rename` renames a relation to a new name in its `Create` statement, and returns
28/// the updated definition raw sql. Note that the `definition` must be a `Create` statement and the
29/// `new_name` must be a valid identifier, it should be validated before calling this function. To
30/// update all relations that depend on the renamed one, use `alter_relation_rename_refs`.
31pub fn alter_relation_rename(definition: &str, new_name: &str) -> String {
32    // This happens when we try to rename a table that's created by `CREATE TABLE AS`. Remove it
33    // when we support `SHOW CREATE TABLE` for `CREATE TABLE AS`.
34    if definition.is_empty() {
35        tracing::warn!("found empty definition when renaming relation, ignored.");
36        return definition.into();
37    }
38    let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
39    let mut stmt = ast
40        .into_iter()
41        .exactly_one()
42        .expect("should contains only one statement");
43
44    match &mut stmt {
45        Statement::CreateTable { name, .. }
46        | Statement::CreateView { name, .. }
47        | Statement::CreateIndex { name, .. }
48        | Statement::CreateSource {
49            stmt: CreateSourceStatement {
50                source_name: name, ..
51            },
52        }
53        | Statement::CreateSubscription {
54            stmt:
55                CreateSubscriptionStatement {
56                    subscription_name: name,
57                    ..
58                },
59        }
60        | Statement::CreateSink {
61            stmt: CreateSinkStatement {
62                sink_name: name, ..
63            },
64        } => replace_table_name(name, new_name),
65        _ => unreachable!(),
66    };
67
68    stmt.to_string()
69}
70
71/// `alter_relation_rename_refs` updates all references of renamed-relation in the definition of
72/// target relation's `Create` statement.
73pub fn alter_relation_rename_refs(definition: &str, from: &str, to: &str) -> String {
74    let ast = Parser::parse_sql(definition).expect("failed to parse relation definition");
75    let mut stmt = ast
76        .into_iter()
77        .exactly_one()
78        .expect("should contains only one statement");
79
80    match &mut stmt {
81        Statement::CreateTable {
82            query: Some(query), ..
83        }
84        | Statement::CreateView { query, .. }
85        | Statement::Query(query) // Used by view, actually we store a query as the definition of view.
86        | Statement::CreateSink {
87            stmt:
88            CreateSinkStatement {
89                sink_from: CreateSink::AsQuery(query),
90                into_table_name: None,
91                ..
92            },
93        } => {
94            QueryRewriter::rewrite_query(query, from, to);
95        }
96        Statement::CreateIndex { table_name, .. }
97        | Statement::CreateSink {
98            stmt:
99            CreateSinkStatement {
100                sink_from: CreateSink::From(table_name),
101                into_table_name: None,
102                ..
103            },
104        }| Statement::CreateSubscription {
105            stmt:
106            CreateSubscriptionStatement {
107                subscription_from: table_name,
108                ..
109            },
110        } | Statement::CreateTable {
111            cdc_table_info:
112            Some(CdcTableInfo {
113                source_name: table_name,
114                ..
115            }),
116            ..
117        } => replace_table_name(table_name, to),
118        Statement::CreateSink {
119            stmt: CreateSinkStatement {
120                sink_from,
121                into_table_name: Some(table_name),
122                ..
123            }
124        } => {
125            let idx = table_name.0.len() - 1;
126            if table_name.0[idx].real_value() == from {
127                table_name.0[idx] = Ident::from_real_value(to);
128            } else {
129                match sink_from {
130                    CreateSink::From(table_name) => replace_table_name(table_name, to),
131                    CreateSink::AsQuery(query) => QueryRewriter::rewrite_query(query, from, to),
132                }
133            }
134        }
135        _ => unreachable!(),
136    };
137    stmt.to_string()
138}
139
140/// Replace the last ident in the `table_name` with the given name, the object name is ensured to be
141/// non-empty. e.g. `schema.table` or `database.schema.table`.
142fn replace_table_name(table_name: &mut ObjectName, to: &str) {
143    let idx = table_name.0.len() - 1;
144    table_name.0[idx] = Ident::from_real_value(to);
145}
146
147/// `QueryRewriter` is a visitor that updates all references of relation named `from` to `to` in the
148/// given query, which is the part of create statement of `relation`.
149struct QueryRewriter<'a> {
150    from: &'a str,
151    to: &'a str,
152}
153
154impl QueryRewriter<'_> {
155    fn rewrite_query(query: &mut Query, from: &str, to: &str) {
156        let rewriter = QueryRewriter { from, to };
157        rewriter.visit_query(query)
158    }
159
160    /// Visit the query and update all references of relation named `from` to `to`.
161    fn visit_query(&self, query: &mut Query) {
162        if let Some(with) = &mut query.with {
163            for cte_table in &mut with.cte_tables {
164                match &mut cte_table.cte_inner {
165                    risingwave_sqlparser::ast::CteInner::Query(query) => self.visit_query(query),
166                    risingwave_sqlparser::ast::CteInner::ChangeLog(name) => {
167                        let idx = name.0.len() - 1;
168                        if name.0[idx].real_value() == self.from {
169                            replace_table_name(name, self.to);
170                        }
171                    }
172                }
173            }
174        }
175        self.visit_set_expr(&mut query.body);
176        for expr in &mut query.order_by {
177            self.visit_expr(&mut expr.expr);
178        }
179    }
180
181    /// Visit table factor and update all references of relation named `from` to `to`.
182    /// Rewrite idents(i.e. `schema.table`, `table`) that contains the old name in the
183    /// following pattern:
184    /// 1. `FROM a` to `FROM new_a AS a`
185    /// 2. `FROM a AS b` to `FROM new_a AS b`
186    ///
187    /// So that we DON'T have to:
188    /// 1. rewrite the select and expr part like `schema.table.column`, `table.column`,
189    ///    `alias.column` etc.
190    /// 2. handle the case that the old name is used as alias.
191    /// 3. handle the case that the new name is used as alias.
192    fn visit_table_factor(&self, table_factor: &mut TableFactor) {
193        match table_factor {
194            TableFactor::Table { name, alias, .. } => {
195                let idx = name.0.len() - 1;
196                if name.0[idx].real_value() == self.from {
197                    if alias.is_none() {
198                        *alias = Some(TableAlias {
199                            name: Ident::from_real_value(self.from),
200                            columns: vec![],
201                        });
202                    }
203                    name.0[idx] = Ident::from_real_value(self.to);
204                }
205            }
206            TableFactor::Derived { subquery, .. } => self.visit_query(subquery),
207            TableFactor::TableFunction { args, .. } => {
208                for arg in args {
209                    self.visit_function_arg(arg);
210                }
211            }
212            TableFactor::NestedJoin(table_with_joins) => {
213                self.visit_table_with_joins(table_with_joins);
214            }
215        }
216    }
217
218    /// Visit table with joins and update all references of relation named `from` to `to`.
219    fn visit_table_with_joins(&self, table_with_joins: &mut TableWithJoins) {
220        self.visit_table_factor(&mut table_with_joins.relation);
221        for join in &mut table_with_joins.joins {
222            self.visit_table_factor(&mut join.relation);
223        }
224    }
225
226    /// Visit query body expression and update all references.
227    fn visit_set_expr(&self, set_expr: &mut SetExpr) {
228        match set_expr {
229            SetExpr::Select(select) => {
230                if let Distinct::DistinctOn(exprs) = &mut select.distinct {
231                    for expr in exprs {
232                        self.visit_expr(expr);
233                    }
234                }
235                for select_item in &mut select.projection {
236                    self.visit_select_item(select_item);
237                }
238                for from_item in &mut select.from {
239                    self.visit_table_with_joins(from_item);
240                }
241                if let Some(where_clause) = &mut select.selection {
242                    self.visit_expr(where_clause);
243                }
244                for expr in &mut select.group_by {
245                    self.visit_expr(expr);
246                }
247                if let Some(having) = &mut select.having {
248                    self.visit_expr(having);
249                }
250                for named_window in &mut select.window {
251                    for expr in &mut named_window.window_spec.partition_by {
252                        self.visit_expr(expr);
253                    }
254                    for expr in &mut named_window.window_spec.order_by {
255                        self.visit_expr(&mut expr.expr);
256                    }
257                }
258            }
259            SetExpr::Query(query) => self.visit_query(query),
260            SetExpr::SetOperation { left, right, .. } => {
261                self.visit_set_expr(left);
262                self.visit_set_expr(right);
263            }
264            SetExpr::Values(_) => {}
265        }
266    }
267
268    /// Visit function arguments and update all references.
269    fn visit_function_arg(&self, function_arg: &mut FunctionArg) {
270        match function_arg {
271            FunctionArg::Unnamed(arg) | FunctionArg::Named { arg, .. } => match arg {
272                FunctionArgExpr::Expr(expr) | FunctionArgExpr::ExprQualifiedWildcard(expr, _) => {
273                    self.visit_expr(expr)
274                }
275                FunctionArgExpr::QualifiedWildcard(_, None) | FunctionArgExpr::Wildcard(None) => {}
276                FunctionArgExpr::QualifiedWildcard(_, Some(exprs))
277                | FunctionArgExpr::Wildcard(Some(exprs)) => {
278                    for expr in exprs {
279                        self.visit_expr(expr);
280                    }
281                }
282                FunctionArgExpr::SecretRef(_) => {}
283            },
284        }
285    }
286
287    fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) {
288        for arg in &mut arg_list.args {
289            self.visit_function_arg(arg);
290        }
291        for expr in &mut arg_list.order_by {
292            self.visit_expr(&mut expr.expr)
293        }
294    }
295
296    /// Visit function and update all references.
297    fn visit_function(&self, function: &mut Function) {
298        self.visit_function_arg_list(&mut function.arg_list);
299        if let Some(over) = &mut function.over {
300            match over {
301                Window::Spec(window) => {
302                    for expr in &mut window.partition_by {
303                        self.visit_expr(expr);
304                    }
305                    for expr in &mut window.order_by {
306                        self.visit_expr(&mut expr.expr);
307                    }
308                }
309                Window::Name(_) => {
310                    // Named window references don't contain expressions to rewrite
311                }
312            }
313        }
314    }
315
316    /// Visit expression and update all references.
317    fn visit_expr(&self, expr: &mut Expr) {
318        match expr {
319            Expr::FieldIdentifier(expr, ..)
320            | Expr::IsNull(expr)
321            | Expr::IsNotNull(expr)
322            | Expr::IsTrue(expr)
323            | Expr::IsNotTrue(expr)
324            | Expr::IsFalse(expr)
325            | Expr::IsNotFalse(expr)
326            | Expr::IsUnknown(expr)
327            | Expr::IsNotUnknown(expr)
328            | Expr::IsJson { expr, .. }
329            | Expr::InList { expr, .. }
330            | Expr::SomeOp(expr)
331            | Expr::AllOp(expr)
332            | Expr::UnaryOp { expr, .. }
333            | Expr::Cast { expr, .. }
334            | Expr::TryCast { expr, .. }
335            | Expr::AtTimeZone {
336                timestamp: expr, ..
337            }
338            | Expr::Extract { expr, .. }
339            | Expr::Substring { expr, .. }
340            | Expr::Overlay { expr, .. }
341            | Expr::Trim { expr, .. }
342            | Expr::Nested(expr)
343            | Expr::Index { obj: expr, .. }
344            | Expr::ArrayRangeIndex { obj: expr, .. } => self.visit_expr(expr),
345
346            Expr::Position { substring, string } => {
347                self.visit_expr(substring);
348                self.visit_expr(string);
349            }
350
351            Expr::InSubquery { expr, subquery, .. } => {
352                self.visit_expr(expr);
353                self.visit_query(subquery);
354            }
355            Expr::Between {
356                expr, low, high, ..
357            } => {
358                self.visit_expr(expr);
359                self.visit_expr(low);
360                self.visit_expr(high);
361            }
362            Expr::Like {
363                expr, pattern: pat, ..
364            } => {
365                self.visit_expr(expr);
366                self.visit_expr(pat);
367            }
368            Expr::ILike {
369                expr, pattern: pat, ..
370            } => {
371                self.visit_expr(expr);
372                self.visit_expr(pat);
373            }
374            Expr::SimilarTo {
375                expr, pattern: pat, ..
376            } => {
377                self.visit_expr(expr);
378                self.visit_expr(pat);
379            }
380
381            Expr::IsDistinctFrom(expr1, expr2)
382            | Expr::IsNotDistinctFrom(expr1, expr2)
383            | Expr::BinaryOp {
384                left: expr1,
385                right: expr2,
386                ..
387            } => {
388                self.visit_expr(expr1);
389                self.visit_expr(expr2);
390            }
391            Expr::Function(function) => self.visit_function(function),
392            Expr::Exists(query) | Expr::Subquery(query) | Expr::ArraySubquery(query) => {
393                self.visit_query(query)
394            }
395
396            Expr::GroupingSets(exprs_vec) | Expr::Cube(exprs_vec) | Expr::Rollup(exprs_vec) => {
397                for exprs in exprs_vec {
398                    for expr in exprs {
399                        self.visit_expr(expr);
400                    }
401                }
402            }
403
404            Expr::Row(exprs) | Expr::Array(Array { elem: exprs, .. }) => {
405                for expr in exprs {
406                    self.visit_expr(expr);
407                }
408            }
409            Expr::Map { entries } => {
410                for (key, value) in entries {
411                    self.visit_expr(key);
412                    self.visit_expr(value);
413                }
414            }
415
416            Expr::LambdaFunction { body, args: _ } => self.visit_expr(body),
417
418            // No need to visit.
419            Expr::Identifier(_)
420            | Expr::CompoundIdentifier(_)
421            | Expr::Collate { .. }
422            | Expr::Value(_)
423            | Expr::Parameter { .. }
424            | Expr::TypedString { .. }
425            | Expr::Case { .. } => {}
426        }
427    }
428
429    /// Visit select item and update all references.
430    fn visit_select_item(&self, select_item: &mut SelectItem) {
431        match select_item {
432            SelectItem::UnnamedExpr(expr)
433            | SelectItem::ExprQualifiedWildcard(expr, _)
434            | SelectItem::ExprWithAlias { expr, .. } => self.visit_expr(expr),
435            SelectItem::QualifiedWildcard(_, None) | SelectItem::Wildcard(None) => {}
436            SelectItem::QualifiedWildcard(_, Some(exprs)) | SelectItem::Wildcard(Some(exprs)) => {
437                for expr in exprs {
438                    self.visit_expr(expr);
439                }
440            }
441        }
442    }
443}
444
445/// Rewrite the expression in index item after there's a schema change on the primary table.
446// TODO: move this out of `rename.rs`, this has nothing to do with renaming.
447pub struct IndexItemRewriter {
448    pub original_columns: Vec<PbColumnDesc>,
449    pub new_columns: Vec<PbColumnDesc>,
450}
451
452impl IndexItemRewriter {
453    pub fn rewrite_expr(&self, expr: &mut ExprNode) {
454        let rex_node = expr.rex_node.as_mut().unwrap();
455        match rex_node {
456            RexNode::InputRef(idx) => {
457                let old_idx = *idx as usize;
458                let original_column = &self.original_columns[old_idx];
459                let (new_idx, new_column) = self
460                    .new_columns
461                    .iter()
462                    .find_position(|c| c.column_id == original_column.column_id)
463                    .expect("should already checked index referencing column still exists");
464                *idx = new_idx as u32;
465
466                // If there's a type change, we need to wrap it with an internal `CompositeCast` to
467                // maintain the correct return type. It cannot execute and will be eliminated in
468                // the frontend when rebuilding the index items.
469                if new_column.column_type != original_column.column_type {
470                    let old_type = original_column.column_type.clone().unwrap();
471                    let new_type = new_column.column_type.clone().unwrap();
472
473                    assert_eq!(&old_type, expr.return_type.as_ref().unwrap());
474                    expr.return_type = Some(new_type); // update return type of `InputRef`
475
476                    let new_expr_node = ExprNode {
477                        function_type: expr_node::Type::CompositeCast as _,
478                        return_type: Some(old_type),
479                        rex_node: RexNode::FuncCall(FunctionCall {
480                            children: vec![expr.clone()],
481                        })
482                        .into(),
483                    };
484
485                    *expr = new_expr_node;
486                }
487            }
488            RexNode::Constant(_) => {}
489            RexNode::Udf(udf) => self.rewrite_udf(udf),
490            RexNode::FuncCall(function_call) => self.rewrite_function_call(function_call),
491            RexNode::Now(_) | RexNode::SecretRef(_) => {}
492        }
493    }
494
495    fn rewrite_udf(&self, udf: &mut UserDefinedFunction) {
496        udf.children
497            .iter_mut()
498            .for_each(|expr| self.rewrite_expr(expr));
499    }
500
501    fn rewrite_function_call(&self, function_call: &mut FunctionCall) {
502        function_call
503            .children
504            .iter_mut()
505            .for_each(|expr| self.rewrite_expr(expr));
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Renames `foo` to `bar` in `definition` and checks that the rewritten SQL equals `expected`.
    fn check_foo_to_bar(definition: &str, expected: &str) {
        let actual = alter_relation_rename_refs(definition, "foo", "bar");
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_alter_table_rename() {
        let actual = alter_relation_rename("CREATE TABLE foo (a int, b int)", "bar");
        assert_eq!("CREATE TABLE bar (a INT, b INT)", actual);
    }

    #[test]
    fn test_rename_index_refs() {
        check_foo_to_bar(
            "CREATE INDEX idx1 ON foo(v1 DESC, v2)",
            "CREATE INDEX idx1 ON bar(v1 DESC, v2)",
        );
    }

    #[test]
    fn test_rename_sink_refs() {
        check_foo_to_bar(
            "CREATE SINK sink_t FROM foo WITH (connector = 'kafka', format = 'append_only')",
            "CREATE SINK sink_t FROM bar WITH (connector = 'kafka', format = 'append_only')",
        );
    }

    #[test]
    fn test_rename_with_alias_refs() {
        // A bare reference gains an alias with the old name, so qualified columns keep resolving.
        check_foo_to_bar(
            "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM foo",
            "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, foo.v2 AS m2v FROM bar AS foo",
        );

        // Field accesses and predicates on the old name are left untouched.
        check_foo_to_bar(
            "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
            "CREATE MATERIALIZED VIEW mv1 AS SELECT foo.v1 AS m1v, (foo.v2).v3 AS m2v FROM bar AS foo WHERE foo.v1 = 1 AND (foo.v2).v3 IS TRUE",
        );

        // An existing alias is kept, even when it coincides with the new name.
        check_foo_to_bar(
            "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM foo AS bar WHERE bar.v1 = 1",
            "CREATE MATERIALIZED VIEW mv1 AS SELECT bar.v1 AS m1v, (bar.v2).v3 AS m2v FROM bar AS bar WHERE bar.v1 = 1",
        );
    }

    #[test]
    fn test_rename_with_complex_funcs() {
        check_foo_to_bar(
            "CREATE MATERIALIZED VIEW mv1 AS SELECT \
             agg1(\
             foo.v1, func2(foo.v2) \
             ORDER BY \
             (SELECT foo.v3 FROM foo), \
             (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM foo) ORDER BY (SELECT foo.v6 FROM foo)) FROM foo)\
             ) \
             FROM foo",
            "CREATE MATERIALIZED VIEW mv1 AS SELECT \
             agg1(\
             foo.v1, func2(foo.v2) \
             ORDER BY \
             (SELECT foo.v3 FROM bar AS foo), \
             (SELECT first_value(foo.v4) OVER (PARTITION BY (SELECT foo.v5 FROM bar AS foo) ORDER BY (SELECT foo.v6 FROM bar AS foo)) FROM bar AS foo)\
             ) \
             FROM bar AS foo",
        );
    }
}