risingwave_frontend/catalog/
purify.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::Context;
16use itertools::Itertools;
17use prost::Message as _;
18use risingwave_common::bail;
19use risingwave_common::catalog::{ColumnCatalog, ColumnId};
20use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
21use risingwave_sqlparser::ast::*;
22
23use crate::error::Result;
24use crate::utils::data_type::DataTypeToAst as _;
25
26mod pk_column {
27    use super::*;
28    // Identifies a primary key column...
29    pub trait PkColumn {
30        fn is(&self, column: &ColumnCatalog) -> bool;
31    }
32    // ...by column name.
33    impl PkColumn for &str {
34        fn is(&self, column: &ColumnCatalog) -> bool {
35            column.name() == *self
36        }
37    }
38    // ...by column ID.
39    impl PkColumn for ColumnId {
40        fn is(&self, column: &ColumnCatalog) -> bool {
41            column.column_id() == *self
42        }
43    }
44}
45use pk_column::PkColumn;
46
47/// Try to restore missing column definitions and constraints in the persisted table (or source)
48/// definition, if the schema is derived from external systems (like schema registry) or it's
49/// created by `CREATE TABLE AS`.
50///
51/// Returns error if restoring failed, or the persisted definition is invalid.
52pub fn try_purify_table_source_create_sql_ast(
53    mut base: Statement,
54    columns: &[ColumnCatalog],
55    row_id_index: Option<usize>,
56    pk_column_ids: &[impl PkColumn],
57) -> Result<Statement> {
58    let (Statement::CreateTable {
59        columns: column_defs,
60        constraints,
61        wildcard_idx,
62        ..
63    }
64    | Statement::CreateSource {
65        stmt:
66            CreateSourceStatement {
67                columns: column_defs,
68                constraints,
69                wildcard_idx,
70                ..
71            },
72    }) = &mut base
73    else {
74        bail!("expect `CREATE TABLE` or `CREATE SOURCE` statement, found: `{base:?}`");
75    };
76
77    // First, remove the wildcard from the definition.
78    *wildcard_idx = None;
79
80    // Filter out columns that are not defined by users in SQL.
81    let defined_columns = columns.iter().filter(|c| c.is_user_defined());
82
83    // Derive `ColumnDef` from `ColumnCatalog`.
84    let mut purified_column_defs = Vec::new();
85    for column in defined_columns {
86        let mut column_def = if let Some(existing) = column_defs
87            .iter()
88            .find(|c| c.name.real_value() == column.name())
89        {
90            // If the column is already defined in the persisted definition, retrieve it.
91            existing.clone()
92        } else {
93            assert!(
94                !column.is_generated(),
95                "generated column must not be inferred"
96            );
97
98            // Generate a new `ColumnDef` from the catalog.
99            ColumnDef {
100                name: column.name().into(),
101                data_type: Some(column.data_type().to_ast()),
102                collation: None,
103                options: Vec::new(), // pk will be specified with table constraints
104            }
105        };
106
107        // Fill in the persisted default value desc.
108        if let Some(c) = &column.column_desc.generated_or_default_column
109            && let GeneratedOrDefaultColumn::DefaultColumn(desc) = c
110        {
111            let persisted = desc.encode_to_vec().into_boxed_slice();
112
113            let default_value_option = column_def
114                .options
115                .extract_if(.., |o| {
116                    matches!(
117                        o.option,
118                        ColumnOption::DefaultValue { .. }
119                            | ColumnOption::DefaultValueInternal { .. }
120                    )
121                })
122                .at_most_one()
123                .ok()
124                .context("multiple default value options found")?;
125
126            let expr = default_value_option.and_then(|o| match o.option {
127                ColumnOption::DefaultValue(expr) => Some(expr),
128                ColumnOption::DefaultValueInternal { expr, .. } => expr,
129                _ => unreachable!(),
130            });
131
132            column_def.options.push(ColumnOptionDef {
133                name: None,
134                option: ColumnOption::DefaultValueInternal { persisted, expr },
135            });
136        }
137
138        purified_column_defs.push(column_def);
139    }
140    *column_defs = purified_column_defs;
141
142    // Specify user-defined primary key in table constraints.
143    let has_pk_column_constraint = column_defs.iter().any(|c| {
144        c.options
145            .iter()
146            .any(|o| matches!(o.option, ColumnOption::Unique { is_primary: true }))
147    });
148    if !has_pk_column_constraint && row_id_index.is_none() {
149        let mut pk_columns = Vec::new();
150
151        for id in pk_column_ids {
152            let column = columns
153                .iter()
154                .find(|c| id.is(c))
155                .context("primary key column not found")?;
156            if !column.is_user_defined() {
157                bail /* unlikely */ !(
158                    "primary key column \"{}\" is not user-defined",
159                    column.name()
160                );
161            }
162            pk_columns.push(column.name().into());
163        }
164
165        let pk_constraint = TableConstraint::Unique {
166            name: None,
167            columns: pk_columns,
168            is_primary: true,
169        };
170
171        // We don't support table constraints other than `PRIMARY KEY`, thus simply overwrite.
172        assert!(
173            constraints.len() <= 1
174                && constraints.iter().all(|c| matches!(
175                    c,
176                    TableConstraint::Unique {
177                        is_primary: true,
178                        ..
179                    }
180                )),
181            "unexpected table constraints: {constraints:?}",
182        );
183
184        *constraints = vec![pk_constraint];
185    }
186
187    Ok(base)
188}