risingwave_frontend/handler/
alter_table_column.rs1use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use itertools::Itertools;
19use pgwire::pg_response::{PgResponse, StatementType};
20use risingwave_common::catalog::ColumnCatalog;
21use risingwave_common::hash::VnodeCount;
22use risingwave_common::{bail, bail_not_implemented};
23use risingwave_connector::sink::catalog::SinkCatalog;
24use risingwave_pb::ddl_service::TableJobType;
25use risingwave_pb::stream_plan::stream_node::PbNodeBody;
26use risingwave_pb::stream_plan::{ProjectNode, StreamFragmentGraph};
27use risingwave_sqlparser::ast::{
28 AlterColumnOperation, AlterTableOperation, ColumnOption, ObjectName, Statement,
29};
30
31use super::create_source::SqlColumnStrategy;
32use super::create_table::{ColumnIdGenerator, generate_stream_graph_for_replace_table};
33use super::{HandlerArgs, RwPgResponse};
34use crate::catalog::purify::try_purify_table_source_create_sql_ast;
35use crate::catalog::root_catalog::SchemaPath;
36use crate::catalog::source_catalog::SourceCatalog;
37use crate::catalog::table_catalog::TableType;
38use crate::error::{ErrorCode, Result, RwError};
39use crate::expr::{Expr, ExprImpl, InputRef};
40use crate::handler::create_sink::{fetch_incoming_sinks, insert_merger_to_union_with_project};
41use crate::session::SessionImpl;
42use crate::{Binder, TableCatalog};
43
44pub async fn get_new_table_definition_for_cdc_table(
46 session: &Arc<SessionImpl>,
47 table_name: ObjectName,
48 new_columns: &[ColumnCatalog],
49) -> Result<(Statement, Arc<TableCatalog>)> {
50 let original_catalog = fetch_table_catalog_for_alter(session.as_ref(), &table_name)?;
51
52 assert_eq!(
53 original_catalog.row_id_index, None,
54 "primary key of cdc table must be user defined"
55 );
56
57 let mut definition = original_catalog.create_sql_ast()?;
59
60 {
63 let Statement::CreateTable {
64 columns,
65 constraints,
66 ..
67 } = &mut definition
68 else {
69 panic!("unexpected statement: {:?}", definition);
70 };
71
72 columns.clear();
73 constraints.clear();
74 }
75
76 let new_definition = try_purify_table_source_create_sql_ast(
77 definition,
78 new_columns,
79 None,
80 &original_catalog.pk_column_names(),
83 )?;
84
85 Ok((new_definition, original_catalog))
86}
87
88pub async fn get_replace_table_plan(
89 session: &Arc<SessionImpl>,
90 table_name: ObjectName,
91 new_definition: Statement,
92 old_catalog: &Arc<TableCatalog>,
93 sql_column_strategy: SqlColumnStrategy,
94) -> Result<(
95 Option<SourceCatalog>,
96 TableCatalog,
97 StreamFragmentGraph,
98 TableJobType,
99)> {
100 let handler_args = HandlerArgs::new(session.clone(), &new_definition, Arc::from(""))?;
102 let col_id_gen = ColumnIdGenerator::new_alter(old_catalog);
103
104 let (mut graph, table, source, job_type) = generate_stream_graph_for_replace_table(
105 session,
106 table_name,
107 old_catalog,
108 handler_args.clone(),
109 new_definition,
110 col_id_gen,
111 sql_column_strategy,
112 )
113 .await?;
114
115 let incoming_sink_ids: HashSet<_> = old_catalog.incoming_sinks.iter().copied().collect();
116
117 let target_columns = (table.columns.iter())
118 .filter(|col| !col.is_rw_timestamp_column())
119 .cloned()
120 .collect_vec();
121
122 for sink in fetch_incoming_sinks(session, &incoming_sink_ids)? {
123 hijack_merger_for_target_table(
124 &mut graph,
125 &target_columns,
126 &sink,
127 Some(&sink.unique_identity()),
128 )?;
129 }
130
131 let mut table = table;
133 table.incoming_sinks = incoming_sink_ids.iter().copied().collect();
134 table.vnode_count = VnodeCount::set(old_catalog.vnode_count());
135
136 Ok((source, table, graph, job_type))
137}
138
139pub(crate) fn hijack_merger_for_target_table(
140 graph: &mut StreamFragmentGraph,
141 target_columns: &[ColumnCatalog],
142 sink: &SinkCatalog,
143 uniq_identify: Option<&str>,
144) -> Result<()> {
145 let mut sink_columns = sink.original_target_columns.clone();
146 if sink_columns.is_empty() {
147 sink_columns = target_columns.to_vec();
152 }
153
154 let mut exprs = Vec::with_capacity(target_columns.len());
155 let sink_idx_by_col_id = sink_columns
156 .iter()
157 .enumerate()
158 .map(|(idx, col)| (col.column_id(), idx))
159 .collect::<HashMap<_, _>>();
160 let default_column_exprs = TableCatalog::default_column_exprs(target_columns);
161 for (target_idx, target_col) in target_columns.iter().enumerate() {
162 if let Some(idx) = sink_idx_by_col_id.get(&target_col.column_id()) {
163 assert_eq!(
164 target_col.data_type(),
165 sink_columns[*idx].data_type(),
166 "data type mismatch for column {}: {} vs {}",
167 target_col.name(),
168 target_col.data_type(),
169 sink_columns[*idx].data_type()
170 );
171 exprs.push(ExprImpl::InputRef(Box::new(InputRef {
173 data_type: target_col.data_type().clone(),
174 index: *idx,
175 })));
176 } else {
177 exprs.push(default_column_exprs[target_idx].clone());
179 }
180 }
181
182 let pb_project = PbNodeBody::Project(Box::new(ProjectNode {
183 select_list: exprs.iter().map(|expr| expr.to_expr_proto()).collect(),
184 ..Default::default()
185 }));
186
187 for fragment in graph.fragments.values_mut() {
188 if let Some(node) = &mut fragment.node {
189 insert_merger_to_union_with_project(node, &pb_project, uniq_identify);
190 }
191 }
192
193 Ok(())
194}
195
196pub async fn handle_alter_table_column(
199 handler_args: HandlerArgs,
200 table_name: ObjectName,
201 operation: AlterTableOperation,
202) -> Result<RwPgResponse> {
203 let session = handler_args.session;
204 let original_catalog = fetch_table_catalog_for_alter(session.as_ref(), &table_name)?;
205
206 if !original_catalog.incoming_sinks.is_empty() && original_catalog.has_generated_column() {
207 return Err(RwError::from(ErrorCode::BindError(
208 "Alter a table with incoming sink and generated column has not been implemented."
209 .to_owned(),
210 )));
211 }
212
213 if original_catalog.webhook_info.is_some() {
214 return Err(RwError::from(ErrorCode::BindError(
215 "Adding/dropping a column of a table with webhook has not been implemented.".to_owned(),
216 )));
217 }
218
219 let mut definition = original_catalog.create_sql_ast_purified()?;
221 let Statement::CreateTable { columns, .. } = &mut definition else {
222 panic!("unexpected statement: {:?}", definition);
223 };
224
225 if !original_catalog.incoming_sinks.is_empty()
226 && matches!(operation, AlterTableOperation::DropColumn { .. })
227 {
228 return Err(ErrorCode::InvalidInputSyntax(
229 "dropping columns in target table of sinks is not supported".to_owned(),
230 ))?;
231 }
232
233 let sql_column_strategy = match operation {
251 AlterTableOperation::AddColumn {
252 column_def: new_column,
253 } => {
254 let new_column_name = new_column.name.real_value();
257 if columns
258 .iter()
259 .any(|c| c.name.real_value() == new_column_name)
260 {
261 Err(ErrorCode::InvalidInputSyntax(format!(
262 "column \"{new_column_name}\" of table \"{table_name}\" already exists"
263 )))?
264 }
265
266 if new_column
267 .options
268 .iter()
269 .any(|x| matches!(x.option, ColumnOption::GeneratedColumns(_)))
270 {
271 Err(ErrorCode::InvalidInputSyntax(
272 "alter table add generated columns is not supported".to_owned(),
273 ))?
274 }
275
276 if new_column
277 .options
278 .iter()
279 .any(|x| matches!(x.option, ColumnOption::NotNull))
280 && !new_column
281 .options
282 .iter()
283 .any(|x| matches!(x.option, ColumnOption::DefaultValue(_)))
284 {
285 return Err(ErrorCode::InvalidInputSyntax(
286 "alter table add NOT NULL columns must have default value".to_owned(),
287 ))?;
288 }
289
290 columns.push(new_column);
292
293 SqlColumnStrategy::FollowChecked
294 }
295
296 AlterTableOperation::DropColumn {
297 column_name,
298 if_exists,
299 cascade,
300 } => {
301 if cascade {
302 bail_not_implemented!(issue = 6903, "drop column cascade");
303 }
304
305 for column in original_catalog.columns() {
307 if let Some(expr) = column.generated_expr() {
308 let expr = ExprImpl::from_expr_proto(expr)?;
309 let refs = expr.collect_input_refs(original_catalog.columns().len());
310 for idx in refs.ones() {
311 let refed_column = &original_catalog.columns()[idx];
312 if refed_column.name() == column_name.real_value() {
313 bail!(format!(
314 "failed to drop column \"{}\" because it's referenced by a generated column \"{}\"",
315 column_name,
316 column.name()
317 ))
318 }
319 }
320 }
321 }
322
323 let column_name = column_name.real_value();
325 let removed_column = columns
326 .extract_if(.., |c| c.name.real_value() == column_name)
327 .at_most_one()
328 .ok()
329 .unwrap();
330
331 if removed_column.is_some() {
332 } else if if_exists {
334 return Ok(PgResponse::builder(StatementType::ALTER_TABLE)
335 .notice(format!(
336 "column \"{}\" does not exist, skipping",
337 column_name
338 ))
339 .into());
340 } else {
341 Err(ErrorCode::InvalidInputSyntax(format!(
342 "column \"{}\" of table \"{}\" does not exist",
343 column_name, table_name
344 )))?
345 }
346
347 SqlColumnStrategy::FollowUnchecked
348 }
349
350 AlterTableOperation::AlterColumn { column_name, op } => {
351 let AlterColumnOperation::SetDataType {
352 data_type,
353 using: None,
354 } = op
355 else {
356 bail_not_implemented!(issue = 6903, "{op}");
357 };
358
359 let column_name = column_name.real_value();
361 let column = columns
362 .iter_mut()
363 .find(|c| c.name.real_value() == column_name)
364 .ok_or_else(|| {
365 ErrorCode::InvalidInputSyntax(format!(
366 "column \"{}\" of table \"{}\" does not exist",
367 column_name, table_name
368 ))
369 })?;
370
371 column.data_type = Some(data_type);
372
373 SqlColumnStrategy::FollowChecked
374 }
375
376 _ => unreachable!(),
377 };
378 let (source, table, graph, job_type) = get_replace_table_plan(
379 &session,
380 table_name,
381 definition,
382 &original_catalog,
383 sql_column_strategy,
384 )
385 .await?;
386
387 let catalog_writer = session.catalog_writer()?;
388
389 catalog_writer
390 .replace_table(
391 source.map(|x| x.to_prost()),
392 table.to_prost(),
393 graph,
394 job_type,
395 )
396 .await?;
397 Ok(PgResponse::empty_result(StatementType::ALTER_TABLE))
398}
399
400pub fn fetch_table_catalog_for_alter(
401 session: &SessionImpl,
402 table_name: &ObjectName,
403) -> Result<Arc<TableCatalog>> {
404 let db_name = &session.database();
405 let (schema_name, real_table_name) =
406 Binder::resolve_schema_qualified_name(db_name, table_name)?;
407 let search_path = session.config().search_path();
408 let user_name = &session.user_name();
409
410 let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name);
411
412 let original_catalog = {
413 let reader = session.env().catalog_reader().read_guard();
414 let (table, schema_name) =
415 reader.get_created_table_by_name(db_name, schema_path, &real_table_name)?;
416
417 match table.table_type() {
418 TableType::Table => {}
419
420 _ => Err(ErrorCode::InvalidInputSyntax(format!(
421 "\"{table_name}\" is not a table or cannot be altered"
422 )))?,
423 }
424
425 session.check_privilege_for_drop_alter(schema_name, &**table)?;
426
427 table.clone()
428 };
429
430 Ok(original_catalog)
431}
432
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use risingwave_common::catalog::{
        DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, ROW_ID_COLUMN_NAME,
    };
    use risingwave_common::types::DataType;

    use crate::catalog::root_catalog::SchemaPath;
    use crate::test_utils::LocalFrontend;

    /// `ALTER TABLE ... ADD COLUMN` appends exactly one column, keeps existing columns unchanged
    /// (type and id), and bumps the table version and next column id by one.
    #[tokio::test]
    async fn test_add_column_handler() {
        let frontend = LocalFrontend::new(Default::default()).await;
        let session = frontend.session_ref();
        let schema_path = SchemaPath::Name(DEFAULT_SCHEMA_NAME);

        frontend
            .run_sql("create table t (i int, r real);")
            .await
            .unwrap();

        // Reads the current catalog entry of `t`.
        let get_table = || {
            let reader = session.env().catalog_reader().read_guard();
            let (table, _) = reader
                .get_created_table_by_name(DEFAULT_DATABASE_NAME, schema_path, "t")
                .unwrap();
            table.clone()
        };

        let before = get_table();
        let before_columns: HashMap<_, _> = before
            .columns
            .iter()
            .map(|c| (c.name(), (c.data_type().clone(), c.column_id())))
            .collect();

        frontend
            .run_sql("alter table t add column s text;")
            .await
            .unwrap();

        let after = get_table();
        let after_columns: HashMap<_, _> = after
            .columns
            .iter()
            .map(|c| (c.name(), (c.data_type().clone(), c.column_id())))
            .collect();

        assert_eq!(before_columns.len() + 1, after_columns.len());
        assert_eq!(after_columns["s"].0, DataType::Varchar);

        for name in ["i", "r", ROW_ID_COLUMN_NAME] {
            assert_eq!(before_columns[name], after_columns[name]);
        }

        let before_version = before.version.as_ref().unwrap();
        let after_version = after.version.as_ref().unwrap();
        assert_eq!(before_version.version_id + 1, after_version.version_id);
        assert_eq!(
            before_version.next_column_id.next(),
            after_version.next_column_id
        );
    }
}