// risingwave_frontend/handler/alter_set_schema.rs
use pgwire::pg_response::StatementType;
use risingwave_pb::ddl_service::alter_set_schema_request::Object;
use risingwave_sqlparser::ast::{ObjectName, OperateFunctionArg};
use super::{HandlerArgs, RwPgResponse};
use crate::catalog::root_catalog::SchemaPath;
use crate::error::{ErrorCode, Result};
use crate::{bind_data_type, Binder};
/// Handles an `ALTER ... SET SCHEMA` statement.
///
/// The object named by `obj_name` is looked up in the catalog according to `stmt_type`
/// (table / materialized view, view, source, sink, subscription, connection or function).
/// For every kind the same sequence is applied:
///
/// 1. If the object already lives in the requested schema, an empty response is returned
///    immediately and nothing else is checked.
/// 2. `check_privilege_for_drop_alter` is invoked on the session for the object.
/// 3. The catalog is asked whether the object's name is already taken in the target schema.
///
/// Once the object has been identified, the target schema id is resolved through the session
/// and the change is submitted via the catalog writer's `alter_set_schema`.
///
/// `func_args` is only consulted for `ALTER FUNCTION`: when present, the argument types are
/// bound and used to select the function; when absent, the name alone must match at most one
/// function.
///
/// # Errors
///
/// Propagates any error from name resolution, catalog lookups, the privilege check, the
/// duplicate-name checks, argument type binding and the catalog writer. Returns a
/// `CatalogError` when a function is addressed without arguments and more than one function
/// carries that name.
///
/// # Panics
///
/// Panics via `unreachable!()` when `stmt_type` is none of the `ALTER_*` variants handled
/// below, and via `expect("no functions")` if the by-name function lookup succeeds with an
/// empty list.
pub async fn handle_alter_set_schema(
    handler_args: HandlerArgs,
    obj_name: ObjectName,
    new_schema_name: ObjectName,
    stmt_type: StatementType,
    func_args: Option<Vec<OperateFunctionArg>>,
) -> Result<RwPgResponse> {
    let session = handler_args.session;
    let database = session.database();
    let (schema_name, real_obj_name) =
        Binder::resolve_schema_qualified_name(database, obj_name.clone())?;
    let search_path = session.config().search_path();
    let user = &session.auth_context().user_name;
    let lookup_path = SchemaPath::new(schema_name.as_deref(), &search_path, user);
    let target_schema = Binder::resolve_schema_name(new_schema_name)?;

    // The catalog read guard is confined to this block, which ends before the `.await` below.
    let object = {
        let reader = session.env().catalog_reader().read_guard();
        match stmt_type {
            StatementType::ALTER_TABLE | StatementType::ALTER_MATERIALIZED_VIEW => {
                let (table, current_schema) =
                    reader.get_created_table_by_name(database, lookup_path, &real_obj_name)?;
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**table)?;
                reader.check_relation_name_duplicated(database, &target_schema, table.name())?;
                Object::TableId(table.id.table_id)
            }
            StatementType::ALTER_VIEW => {
                let (view, current_schema) =
                    reader.get_view_by_name(database, lookup_path, &real_obj_name)?;
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**view)?;
                reader.check_relation_name_duplicated(database, &target_schema, view.name())?;
                Object::ViewId(view.id)
            }
            StatementType::ALTER_SOURCE => {
                let (source, current_schema) =
                    reader.get_source_by_name(database, lookup_path, &real_obj_name)?;
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**source)?;
                reader.check_relation_name_duplicated(database, &target_schema, &source.name)?;
                Object::SourceId(source.id)
            }
            StatementType::ALTER_SINK => {
                let (sink, current_schema) =
                    reader.get_sink_by_name(database, lookup_path, &real_obj_name)?;
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**sink)?;
                reader.check_relation_name_duplicated(database, &target_schema, &sink.name)?;
                Object::SinkId(sink.id.sink_id)
            }
            StatementType::ALTER_SUBSCRIPTION => {
                let (subscription, current_schema) =
                    reader.get_subscription_by_name(database, lookup_path, &real_obj_name)?;
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**subscription)?;
                reader.check_relation_name_duplicated(
                    database,
                    &target_schema,
                    &subscription.name,
                )?;
                Object::SubscriptionId(subscription.id.subscription_id)
            }
            StatementType::ALTER_CONNECTION => {
                let (connection, current_schema) =
                    reader.get_connection_by_name(database, lookup_path, &real_obj_name)?;
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**connection)?;
                reader.check_connection_name_duplicated(
                    database,
                    &target_schema,
                    &connection.name,
                )?;
                Object::ConnectionId(connection.id)
            }
            StatementType::ALTER_FUNCTION => {
                let (function, current_schema) = match func_args {
                    // An explicit argument list selects the function by name and signature.
                    Some(args) => {
                        let arg_types = args
                            .into_iter()
                            .map(|arg| bind_data_type(&arg.data_type))
                            .collect::<std::result::Result<Vec<_>, _>>()?;
                        reader.get_function_by_name_args(
                            database,
                            lookup_path,
                            &real_obj_name,
                            &arg_types,
                        )?
                    }
                    // Without arguments the name alone must not match several functions.
                    None => {
                        let (candidates, schema_of_candidates) =
                            reader.get_functions_by_name(database, lookup_path, &real_obj_name)?;
                        if candidates.len() > 1 {
                            return Err(ErrorCode::CatalogError(format!("function name {real_obj_name:?} is not unique\nHINT: Specify the argument list to select the function unambiguously.").into()).into());
                        }
                        let only_candidate =
                            candidates.into_iter().next().expect("no functions");
                        (only_candidate, schema_of_candidates)
                    }
                };
                if current_schema == target_schema {
                    return Ok(RwPgResponse::empty_result(stmt_type));
                }
                session.check_privilege_for_drop_alter(current_schema, &**function)?;
                reader.check_function_name_duplicated(
                    database,
                    &target_schema,
                    &function.name,
                    &function.arg_types,
                )?;
                Object::FunctionId(function.id.function_id())
            }
            _ => unreachable!(),
        }
    };

    let (_, new_schema_id) = session.get_database_and_schema_id_for_create(Some(target_schema))?;
    session
        .catalog_writer()?
        .alter_set_schema(object, new_schema_id)
        .await?;

    Ok(RwPgResponse::empty_result(stmt_type))
}
#[cfg(test)]
pub mod tests {
    use risingwave_common::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME};

    use crate::catalog::root_catalog::SchemaPath;
    use crate::test_utils::LocalFrontend;

    /// Creates a table and a second schema, runs `ALTER TABLE ... SET SCHEMA`, and verifies
    /// through the catalog that the table is reported under `public` before the statement
    /// and under `test_schema` afterwards.
    #[tokio::test]
    async fn test_alter_set_schema_handler() {
        let frontend = LocalFrontend::new(Default::default()).await;
        let session = frontend.session_ref();

        frontend
            .run_sql("CREATE TABLE test_table (u INT, v INT);")
            .await
            .unwrap();
        frontend
            .run_sql("CREATE SCHEMA test_schema;")
            .await
            .unwrap();

        // Looks up `test_table` inside the given schema of the default database.
        let table_in_schema = |schema_name| {
            let reader = session.env().catalog_reader().read_guard();
            let (table, _) = reader
                .get_created_table_by_name(
                    DEFAULT_DATABASE_NAME,
                    SchemaPath::Name(schema_name),
                    "test_table",
                )
                .unwrap();
            table.clone()
        };

        // Asks the catalog which schema currently holds the table with the given id.
        let schema_of_table = |table_id| {
            let reader = session.env().catalog_reader().read_guard();
            let schema = reader
                .get_schema_by_table_id(DEFAULT_DATABASE_NAME, &table_id)
                .unwrap();
            schema.name()
        };

        let schema_before = schema_of_table(table_in_schema(DEFAULT_SCHEMA_NAME).id);
        assert_eq!(schema_before, "public");

        frontend
            .run_sql("ALTER TABLE test_table SET SCHEMA test_schema;")
            .await
            .unwrap();

        let schema_after = schema_of_table(table_in_schema("test_schema").id);
        assert_eq!(schema_after, "test_schema");
    }
}