risingwave_meta/controller/
utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::{BTreeSet, HashMap, HashSet};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
23use risingwave_common::hash::{ActorMapping, VnodeBitmapExt, WorkerSlotId, WorkerSlotMapping};
24use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
25use risingwave_common::{bail, hash};
26use risingwave_meta_model::actor::ActorStatus;
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::object::ObjectType;
29use risingwave_meta_model::prelude::*;
30use risingwave_meta_model::table::TableType;
31use risingwave_meta_model::user_privilege::Action;
32use risingwave_meta_model::{
33    ActorId, DataTypeArray, DatabaseId, DispatcherType, FragmentId, I32Array, JobStatus, ObjectId,
34    PrivilegeId, SchemaId, SourceId, StreamNode, StreamSourceInfo, TableId, UserId, VnodeBitmap,
35    WorkerId, actor, connection, database, fragment, fragment_relation, function, index, object,
36    object_dependency, schema, secret, sink, source, streaming_job, subscription, table, user,
37    user_default_privilege, user_privilege, view,
38};
39use risingwave_meta_model_migration::WithQuery;
40use risingwave_pb::catalog::{
41    PbConnection, PbDatabase, PbFunction, PbIndex, PbSchema, PbSecret, PbSink, PbSource,
42    PbSubscription, PbTable, PbView,
43};
44use risingwave_pb::common::WorkerNode;
45use risingwave_pb::meta::object::PbObjectInfo;
46use risingwave_pb::meta::subscribe_response::Info as NotificationInfo;
47use risingwave_pb::meta::{
48    FragmentWorkerSlotMapping, PbFragmentWorkerSlotMapping, PbObject, PbObjectGroup,
49};
50use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher, PbDispatcherType};
51use risingwave_pb::user::grant_privilege::{PbActionWithGrantOption, PbObject as PbGrantObject};
52use risingwave_pb::user::{PbAction, PbGrantPrivilege, PbUserInfo};
53use risingwave_sqlparser::ast::Statement as SqlStatement;
54use risingwave_sqlparser::parser::Parser;
55use sea_orm::sea_query::{
56    Alias, CommonTableExpression, Expr, Query, QueryStatementBuilder, SelectStatement, UnionType,
57    WithClause,
58};
59use sea_orm::{
60    ColumnTrait, ConnectionTrait, DatabaseTransaction, DerivePartialModel, EntityTrait,
61    FromQueryResult, IntoActiveModel, JoinType, Order, PaginatorTrait, QueryFilter, QuerySelect,
62    RelationTrait, Set, Statement,
63};
64use thiserror_ext::AsReport;
65
66use crate::controller::ObjectModel;
67use crate::model::{FragmentActorDispatchers, FragmentDownstreamRelation};
68use crate::{MetaError, MetaResult};
69
70/// This function will construct a query using recursive cte to find all objects[(id, `obj_type`)] that are used by the given object.
71///
72/// # Examples
73///
74/// ```
75/// use risingwave_meta::controller::utils::construct_obj_dependency_query;
76/// use sea_orm::sea_query::*;
77/// use sea_orm::*;
78///
79/// let query = construct_obj_dependency_query(1);
80///
81/// assert_eq!(
82///     query.to_string(MysqlQueryBuilder),
83///     r#"WITH RECURSIVE `used_by_object_ids` (`used_by`) AS (SELECT `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `oid` FROM `object` WHERE `object`.`database_id` = 1 OR `object`.`schema_id` = 1) UNION ALL (SELECT `object_dependency`.`used_by` FROM `object_dependency` INNER JOIN `used_by_object_ids` ON `used_by_object_ids`.`used_by` = `oid`)) SELECT DISTINCT `oid`, `obj_type`, `schema_id`, `database_id` FROM `used_by_object_ids` INNER JOIN `object` ON `used_by_object_ids`.`used_by` = `oid` ORDER BY `oid` DESC"#
84/// );
85/// assert_eq!(
86///     query.to_string(PostgresQueryBuilder),
87///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1) UNION ALL (SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid")) SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
88/// );
89/// assert_eq!(
90///     query.to_string(SqliteQueryBuilder),
91///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1 UNION ALL SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid") SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
92/// );
93/// ```
94pub fn construct_obj_dependency_query(obj_id: ObjectId) -> WithQuery {
95    let cte_alias = Alias::new("used_by_object_ids");
96    let cte_return_alias = Alias::new("used_by");
97
98    let mut base_query = SelectStatement::new()
99        .column(object_dependency::Column::UsedBy)
100        .from(ObjectDependency)
101        .and_where(object_dependency::Column::Oid.eq(obj_id))
102        .to_owned();
103
104    let belonged_obj_query = SelectStatement::new()
105        .column(object::Column::Oid)
106        .from(Object)
107        .and_where(
108            object::Column::DatabaseId
109                .eq(obj_id)
110                .or(object::Column::SchemaId.eq(obj_id)),
111        )
112        .to_owned();
113
114    let cte_referencing = Query::select()
115        .column((ObjectDependency, object_dependency::Column::UsedBy))
116        .from(ObjectDependency)
117        .inner_join(
118            cte_alias.clone(),
119            Expr::col((cte_alias.clone(), cte_return_alias.clone()))
120                .equals(object_dependency::Column::Oid),
121        )
122        .to_owned();
123
124    let common_table_expr = CommonTableExpression::new()
125        .query(
126            base_query
127                .union(UnionType::All, belonged_obj_query)
128                .union(UnionType::All, cte_referencing)
129                .to_owned(),
130        )
131        .column(cte_return_alias.clone())
132        .table_name(cte_alias.clone())
133        .to_owned();
134
135    SelectStatement::new()
136        .distinct()
137        .columns([
138            object::Column::Oid,
139            object::Column::ObjType,
140            object::Column::SchemaId,
141            object::Column::DatabaseId,
142        ])
143        .from(cte_alias.clone())
144        .inner_join(
145            Object,
146            Expr::col((cte_alias, cte_return_alias.clone())).equals(object::Column::Oid),
147        )
148        .order_by(object::Column::Oid, Order::Desc)
149        .to_owned()
150        .with(
151            WithClause::new()
152                .recursive(true)
153                .cte(common_table_expr)
154                .to_owned(),
155        )
156        .to_owned()
157}
158
159/// This function will construct a query using recursive cte to find if dependent objects are already relying on the target table.
160///
161/// # Examples
162///
163/// ```
164/// use risingwave_meta::controller::utils::construct_sink_cycle_check_query;
165/// use sea_orm::sea_query::*;
166/// use sea_orm::*;
167///
168/// let query = construct_sink_cycle_check_query(1, vec![2, 3]);
169///
170/// assert_eq!(
171///     query.to_string(MysqlQueryBuilder),
172///     r#"WITH RECURSIVE `used_by_object_ids_with_sink` (`oid`, `used_by`) AS (SELECT `oid`, `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `obj_dependency_with_sink`.`oid`, `obj_dependency_with_sink`.`used_by` FROM (SELECT `oid`, `used_by` FROM `object_dependency` UNION ALL (SELECT `sink_id`, `target_table` FROM `sink` WHERE `sink`.`target_table` IS NOT NULL)) AS `obj_dependency_with_sink` INNER JOIN `used_by_object_ids_with_sink` ON `used_by_object_ids_with_sink`.`used_by` = `obj_dependency_with_sink`.`oid` WHERE `used_by_object_ids_with_sink`.`used_by` <> `used_by_object_ids_with_sink`.`oid`)) SELECT COUNT(`used_by_object_ids_with_sink`.`used_by`) FROM `used_by_object_ids_with_sink` WHERE `used_by_object_ids_with_sink`.`used_by` IN (2, 3)"#
173/// );
174/// assert_eq!(
175///     query.to_string(PostgresQueryBuilder),
176///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL (SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL)) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid")) SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
177/// );
178/// assert_eq!(
179///     query.to_string(SqliteQueryBuilder),
180///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid") SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
181/// );
182/// ```
183pub fn construct_sink_cycle_check_query(
184    target_table: ObjectId,
185    dependent_objects: Vec<ObjectId>,
186) -> WithQuery {
187    let cte_alias = Alias::new("used_by_object_ids_with_sink");
188    let depend_alias = Alias::new("obj_dependency_with_sink");
189
190    let mut base_query = SelectStatement::new()
191        .columns([
192            object_dependency::Column::Oid,
193            object_dependency::Column::UsedBy,
194        ])
195        .from(ObjectDependency)
196        .and_where(object_dependency::Column::Oid.eq(target_table))
197        .to_owned();
198
199    let query_sink_deps = SelectStatement::new()
200        .columns([sink::Column::SinkId, sink::Column::TargetTable])
201        .from(Sink)
202        .and_where(sink::Column::TargetTable.is_not_null())
203        .to_owned();
204
205    let cte_referencing = Query::select()
206        .column((depend_alias.clone(), object_dependency::Column::Oid))
207        .column((depend_alias.clone(), object_dependency::Column::UsedBy))
208        .from_subquery(
209            SelectStatement::new()
210                .columns([
211                    object_dependency::Column::Oid,
212                    object_dependency::Column::UsedBy,
213                ])
214                .from(ObjectDependency)
215                .union(UnionType::All, query_sink_deps)
216                .to_owned(),
217            depend_alias.clone(),
218        )
219        .inner_join(
220            cte_alias.clone(),
221            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).eq(Expr::col((
222                depend_alias.clone(),
223                object_dependency::Column::Oid,
224            ))),
225        )
226        .and_where(
227            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).ne(Expr::col((
228                cte_alias.clone(),
229                object_dependency::Column::Oid,
230            ))),
231        )
232        .to_owned();
233
234    let common_table_expr = CommonTableExpression::new()
235        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
236        .columns([
237            object_dependency::Column::Oid,
238            object_dependency::Column::UsedBy,
239        ])
240        .table_name(cte_alias.clone())
241        .to_owned();
242
243    SelectStatement::new()
244        .expr(Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).count())
245        .from(cte_alias.clone())
246        .and_where(
247            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy))
248                .is_in(dependent_objects),
249        )
250        .to_owned()
251        .with(
252            WithClause::new()
253                .recursive(true)
254                .cte(common_table_expr)
255                .to_owned(),
256        )
257        .to_owned()
258}
259
/// Partial model of the `object` entity: an object's id, its type, and the schema and database
/// ids recorded for it.
///
/// `get_referring_objects_cascade` decodes the rows of [`construct_obj_dependency_query`] into
/// this type.
#[derive(Clone, DerivePartialModel, FromQueryResult, Debug)]
#[sea_orm(entity = "Object")]
pub struct PartialObject {
    /// Id of the object (`object.oid`).
    pub oid: ObjectId,
    /// Type of the object (`object.obj_type`).
    pub obj_type: ObjectType,
    /// Schema id recorded for the object; `None` when the column is null.
    pub schema_id: Option<SchemaId>,
    /// Database id recorded for the object; `None` when the column is null.
    pub database_id: Option<DatabaseId>,
}
268
/// Partial model of the `fragment` entity: a fragment id, the job id it is recorded under, and
/// its `state_table_ids` column.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Fragment")]
pub struct PartialFragmentStateTables {
    /// Id of the fragment.
    pub fragment_id: FragmentId,
    /// Id of the job recorded for the fragment (`fragment.job_id`).
    pub job_id: ObjectId,
    /// Ids stored in the fragment's `state_table_ids` column.
    pub state_table_ids: I32Array,
}
276
/// Partial model of the `actor` entity: an actor id, its fragment, the worker recorded for it, and
/// its status.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Actor")]
pub struct PartialActorLocation {
    /// Id of the actor.
    pub actor_id: ActorId,
    /// Id of the fragment the actor belongs to.
    pub fragment_id: FragmentId,
    /// Id of the worker recorded for the actor (`actor.worker_id`).
    pub worker_id: WorkerId,
    /// Status of the actor.
    pub status: ActorStatus,
}
285
/// Summary row describing one fragment, decoded column-by-name via `FromQueryResult`.
///
/// NOTE(review): this is not tied to an entity, so the producing query defines how each column
/// is computed. In particular `parallelism` (`i64`) is presumably an aggregate such as an actor
/// count; confirm against the caller.
#[derive(FromQueryResult)]
pub struct FragmentDesc {
    /// Id of the fragment.
    pub fragment_id: FragmentId,
    /// Id of the job the fragment is recorded under.
    pub job_id: ObjectId,
    /// Raw fragment type mask (presumably `FragmentTypeFlag` bits; confirm).
    pub fragment_type_mask: i32,
    /// Distribution type of the fragment.
    pub distribution_type: DistributionType,
    /// Ids stored in the fragment's `state_table_ids` column.
    pub state_table_ids: I32Array,
    /// Parallelism reported by the producing query (see the NOTE above).
    pub parallelism: i64,
    /// Vnode count of the fragment.
    pub vnode_count: i32,
    /// Stream plan node of the fragment.
    pub stream_node: StreamNode,
}
297
298/// List all objects that are using the given one in a cascade way. It runs a recursive CTE to find all the dependencies.
299pub async fn get_referring_objects_cascade<C>(
300    obj_id: ObjectId,
301    db: &C,
302) -> MetaResult<Vec<PartialObject>>
303where
304    C: ConnectionTrait,
305{
306    let query = construct_obj_dependency_query(obj_id);
307    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
308    let objects = PartialObject::find_by_statement(Statement::from_sql_and_values(
309        db.get_database_backend(),
310        sql,
311        values,
312    ))
313    .all(db)
314    .await?;
315    Ok(objects)
316}
317
318/// Check if create a sink with given dependent objects into the target table will cause a cycle, return true if it will.
319pub async fn check_sink_into_table_cycle<C>(
320    target_table: ObjectId,
321    dependent_objs: Vec<ObjectId>,
322    db: &C,
323) -> MetaResult<bool>
324where
325    C: ConnectionTrait,
326{
327    if dependent_objs.is_empty() {
328        return Ok(false);
329    }
330
331    // special check for self referencing
332    if dependent_objs.contains(&target_table) {
333        return Ok(true);
334    }
335
336    let query = construct_sink_cycle_check_query(target_table, dependent_objs);
337    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
338
339    let res = db
340        .query_one(Statement::from_sql_and_values(
341            db.get_database_backend(),
342            sql,
343            values,
344        ))
345        .await?
346        .unwrap();
347
348    let cnt: i64 = res.try_get_by(0)?;
349
350    Ok(cnt != 0)
351}
352
353/// `ensure_object_id` ensures the existence of target object in the cluster.
354pub async fn ensure_object_id<C>(
355    object_type: ObjectType,
356    obj_id: ObjectId,
357    db: &C,
358) -> MetaResult<()>
359where
360    C: ConnectionTrait,
361{
362    let count = Object::find_by_id(obj_id).count(db).await?;
363    if count == 0 {
364        return Err(MetaError::catalog_id_not_found(
365            object_type.as_str(),
366            obj_id,
367        ));
368    }
369    Ok(())
370}
371
372/// `ensure_user_id` ensures the existence of target user in the cluster.
373pub async fn ensure_user_id<C>(user_id: UserId, db: &C) -> MetaResult<()>
374where
375    C: ConnectionTrait,
376{
377    let count = User::find_by_id(user_id).count(db).await?;
378    if count == 0 {
379        return Err(anyhow!("user {} was concurrently dropped", user_id).into());
380    }
381    Ok(())
382}
383
384/// `check_database_name_duplicate` checks whether the database name is already used in the cluster.
385pub async fn check_database_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
386where
387    C: ConnectionTrait,
388{
389    let count = Database::find()
390        .filter(database::Column::Name.eq(name))
391        .count(db)
392        .await?;
393    if count > 0 {
394        assert_eq!(count, 1);
395        return Err(MetaError::catalog_duplicated("database", name));
396    }
397    Ok(())
398}
399
400/// `check_function_signature_duplicate` checks whether the function name and its signature is already used in the target namespace.
401pub async fn check_function_signature_duplicate<C>(
402    pb_function: &PbFunction,
403    db: &C,
404) -> MetaResult<()>
405where
406    C: ConnectionTrait,
407{
408    let count = Function::find()
409        .inner_join(Object)
410        .filter(
411            object::Column::DatabaseId
412                .eq(pb_function.database_id as DatabaseId)
413                .and(object::Column::SchemaId.eq(pb_function.schema_id as SchemaId))
414                .and(function::Column::Name.eq(&pb_function.name))
415                .and(
416                    function::Column::ArgTypes
417                        .eq(DataTypeArray::from(pb_function.arg_types.clone())),
418                ),
419        )
420        .count(db)
421        .await?;
422    if count > 0 {
423        assert_eq!(count, 1);
424        return Err(MetaError::catalog_duplicated("function", &pb_function.name));
425    }
426    Ok(())
427}
428
429/// `check_connection_name_duplicate` checks whether the connection name is already used in the target namespace.
430pub async fn check_connection_name_duplicate<C>(
431    pb_connection: &PbConnection,
432    db: &C,
433) -> MetaResult<()>
434where
435    C: ConnectionTrait,
436{
437    let count = Connection::find()
438        .inner_join(Object)
439        .filter(
440            object::Column::DatabaseId
441                .eq(pb_connection.database_id as DatabaseId)
442                .and(object::Column::SchemaId.eq(pb_connection.schema_id as SchemaId))
443                .and(connection::Column::Name.eq(&pb_connection.name)),
444        )
445        .count(db)
446        .await?;
447    if count > 0 {
448        assert_eq!(count, 1);
449        return Err(MetaError::catalog_duplicated(
450            "connection",
451            &pb_connection.name,
452        ));
453    }
454    Ok(())
455}
456
457pub async fn check_secret_name_duplicate<C>(pb_secret: &PbSecret, db: &C) -> MetaResult<()>
458where
459    C: ConnectionTrait,
460{
461    let count = Secret::find()
462        .inner_join(Object)
463        .filter(
464            object::Column::DatabaseId
465                .eq(pb_secret.database_id as DatabaseId)
466                .and(object::Column::SchemaId.eq(pb_secret.schema_id as SchemaId))
467                .and(secret::Column::Name.eq(&pb_secret.name)),
468        )
469        .count(db)
470        .await?;
471    if count > 0 {
472        assert_eq!(count, 1);
473        return Err(MetaError::catalog_duplicated("secret", &pb_secret.name));
474    }
475    Ok(())
476}
477
478pub async fn check_subscription_name_duplicate<C>(
479    pb_subscription: &PbSubscription,
480    db: &C,
481) -> MetaResult<()>
482where
483    C: ConnectionTrait,
484{
485    let count = Subscription::find()
486        .inner_join(Object)
487        .filter(
488            object::Column::DatabaseId
489                .eq(pb_subscription.database_id as DatabaseId)
490                .and(object::Column::SchemaId.eq(pb_subscription.schema_id as SchemaId))
491                .and(subscription::Column::Name.eq(&pb_subscription.name)),
492        )
493        .count(db)
494        .await?;
495    if count > 0 {
496        assert_eq!(count, 1);
497        return Err(MetaError::catalog_duplicated(
498            "subscription",
499            &pb_subscription.name,
500        ));
501    }
502    Ok(())
503}
504
505/// `check_user_name_duplicate` checks whether the user is already existed in the cluster.
506pub async fn check_user_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
507where
508    C: ConnectionTrait,
509{
510    let count = User::find()
511        .filter(user::Column::Name.eq(name))
512        .count(db)
513        .await?;
514    if count > 0 {
515        assert_eq!(count, 1);
516        return Err(MetaError::catalog_duplicated("user", name));
517    }
518    Ok(())
519}
520
521/// `check_relation_name_duplicate` checks whether the relation name is already used in the target namespace.
522pub async fn check_relation_name_duplicate<C>(
523    name: &str,
524    database_id: DatabaseId,
525    schema_id: SchemaId,
526    db: &C,
527) -> MetaResult<()>
528where
529    C: ConnectionTrait,
530{
531    macro_rules! check_duplicated {
532        ($obj_type:expr, $entity:ident, $table:ident) => {
533            let object_id = Object::find()
534                .select_only()
535                .column(object::Column::Oid)
536                .inner_join($entity)
537                .filter(
538                    object::Column::DatabaseId
539                        .eq(Some(database_id))
540                        .and(object::Column::SchemaId.eq(Some(schema_id)))
541                        .and($table::Column::Name.eq(name)),
542                )
543                .into_tuple::<ObjectId>()
544                .one(db)
545                .await?;
546            if let Some(oid) = object_id {
547                let check_creation = if $obj_type == ObjectType::View {
548                    false
549                } else if $obj_type == ObjectType::Source {
550                    let source_info = Source::find_by_id(oid)
551                        .select_only()
552                        .column(source::Column::SourceInfo)
553                        .into_tuple::<Option<StreamSourceInfo>>()
554                        .one(db)
555                        .await?
556                        .unwrap();
557                    source_info.map_or(false, |info| info.to_protobuf().is_shared())
558                } else {
559                    true
560                };
561                return if check_creation
562                    && !matches!(
563                        StreamingJob::find_by_id(oid)
564                            .select_only()
565                            .column(streaming_job::Column::JobStatus)
566                            .into_tuple::<JobStatus>()
567                            .one(db)
568                            .await?,
569                        Some(JobStatus::Created)
570                    ) {
571                    Err(MetaError::catalog_under_creation(
572                        $obj_type.as_str(),
573                        name,
574                        oid,
575                    ))
576                } else {
577                    Err(MetaError::catalog_duplicated($obj_type.as_str(), name))
578                };
579            }
580        };
581    }
582    check_duplicated!(ObjectType::Table, Table, table);
583    check_duplicated!(ObjectType::Source, Source, source);
584    check_duplicated!(ObjectType::Sink, Sink, sink);
585    check_duplicated!(ObjectType::Index, Index, index);
586    check_duplicated!(ObjectType::View, View, view);
587
588    Ok(())
589}
590
591/// `check_schema_name_duplicate` checks whether the schema name is already used in the target database.
592pub async fn check_schema_name_duplicate<C>(
593    name: &str,
594    database_id: DatabaseId,
595    db: &C,
596) -> MetaResult<()>
597where
598    C: ConnectionTrait,
599{
600    let count = Object::find()
601        .inner_join(Schema)
602        .filter(
603            object::Column::ObjType
604                .eq(ObjectType::Schema)
605                .and(object::Column::DatabaseId.eq(Some(database_id)))
606                .and(schema::Column::Name.eq(name)),
607        )
608        .count(db)
609        .await?;
610    if count != 0 {
611        return Err(MetaError::catalog_duplicated("schema", name));
612    }
613
614    Ok(())
615}
616
617/// `check_object_refer_for_drop` checks whether the object is used by other objects except indexes.
618/// It returns an error that contains the details of the referring objects if it is used by others.
619pub async fn check_object_refer_for_drop<C>(
620    object_type: ObjectType,
621    object_id: ObjectId,
622    db: &C,
623) -> MetaResult<()>
624where
625    C: ConnectionTrait,
626{
627    // Ignore indexes.
628    let count = if object_type == ObjectType::Table {
629        ObjectDependency::find()
630            .join(
631                JoinType::InnerJoin,
632                object_dependency::Relation::Object1.def(),
633            )
634            .filter(
635                object_dependency::Column::Oid
636                    .eq(object_id)
637                    .and(object::Column::ObjType.ne(ObjectType::Index)),
638            )
639            .count(db)
640            .await?
641    } else {
642        ObjectDependency::find()
643            .filter(object_dependency::Column::Oid.eq(object_id))
644            .count(db)
645            .await?
646    };
647    if count != 0 {
648        // find the name of all objects that are using the given one.
649        let referring_objects = get_referring_objects(object_id, db).await?;
650        let referring_objs_map = referring_objects
651            .into_iter()
652            .filter(|o| o.obj_type != ObjectType::Index)
653            .into_group_map_by(|o| o.obj_type);
654        let mut details = vec![];
655        for (obj_type, objs) in referring_objs_map {
656            match obj_type {
657                ObjectType::Table => {
658                    let tables: Vec<(String, String)> = Object::find()
659                        .join(JoinType::InnerJoin, object::Relation::Table.def())
660                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
661                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
662                        .select_only()
663                        .column(schema::Column::Name)
664                        .column(table::Column::Name)
665                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
666                        .into_tuple()
667                        .all(db)
668                        .await?;
669                    details.extend(tables.into_iter().map(|(schema_name, table_name)| {
670                        format!(
671                            "materialized view {}.{} depends on it",
672                            schema_name, table_name
673                        )
674                    }));
675                }
676                ObjectType::Sink => {
677                    let sinks: Vec<(String, String)> = Object::find()
678                        .join(JoinType::InnerJoin, object::Relation::Sink.def())
679                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
680                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
681                        .select_only()
682                        .column(schema::Column::Name)
683                        .column(sink::Column::Name)
684                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
685                        .into_tuple()
686                        .all(db)
687                        .await?;
688                    details.extend(sinks.into_iter().map(|(schema_name, sink_name)| {
689                        format!("sink {}.{} depends on it", schema_name, sink_name)
690                    }));
691                }
692                ObjectType::View => {
693                    let views: Vec<(String, String)> = Object::find()
694                        .join(JoinType::InnerJoin, object::Relation::View.def())
695                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
696                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
697                        .select_only()
698                        .column(schema::Column::Name)
699                        .column(view::Column::Name)
700                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
701                        .into_tuple()
702                        .all(db)
703                        .await?;
704                    details.extend(views.into_iter().map(|(schema_name, view_name)| {
705                        format!("view {}.{} depends on it", schema_name, view_name)
706                    }));
707                }
708                ObjectType::Subscription => {
709                    let subscriptions: Vec<(String, String)> = Object::find()
710                        .join(JoinType::InnerJoin, object::Relation::Subscription.def())
711                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
712                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
713                        .select_only()
714                        .column(schema::Column::Name)
715                        .column(subscription::Column::Name)
716                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
717                        .into_tuple()
718                        .all(db)
719                        .await?;
720                    details.extend(subscriptions.into_iter().map(
721                        |(schema_name, subscription_name)| {
722                            format!(
723                                "subscription {}.{} depends on it",
724                                schema_name, subscription_name
725                            )
726                        },
727                    ));
728                }
729                ObjectType::Source => {
730                    let sources: Vec<(String, String)> = Object::find()
731                        .join(JoinType::InnerJoin, object::Relation::Source.def())
732                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
733                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
734                        .select_only()
735                        .column(schema::Column::Name)
736                        .column(source::Column::Name)
737                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
738                        .into_tuple()
739                        .all(db)
740                        .await?;
741                    details.extend(sources.into_iter().map(|(schema_name, view_name)| {
742                        format!("source {}.{} depends on it", schema_name, view_name)
743                    }));
744                }
745                ObjectType::Connection => {
746                    let connections: Vec<(String, String)> = Object::find()
747                        .join(JoinType::InnerJoin, object::Relation::Connection.def())
748                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
749                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
750                        .select_only()
751                        .column(schema::Column::Name)
752                        .column(connection::Column::Name)
753                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
754                        .into_tuple()
755                        .all(db)
756                        .await?;
757                    details.extend(connections.into_iter().map(|(schema_name, view_name)| {
758                        format!("connection {}.{} depends on it", schema_name, view_name)
759                    }));
760                }
761                // only the table, source, sink, subscription, view, connection and index will depend on other objects.
762                _ => bail!("unexpected referring object type: {}", obj_type.as_str()),
763            }
764        }
765
766        return Err(MetaError::permission_denied(format!(
767            "{} used by {} other objects. \nDETAIL: {}\n\
768            {}",
769            object_type.as_str(),
770            count,
771            details.join("\n"),
772            match object_type {
773                ObjectType::Function | ObjectType::Connection | ObjectType::Secret =>
774                    "HINT: DROP the dependent objects first.",
775                ObjectType::Database | ObjectType::Schema => unreachable!(),
776                _ => "HINT:  Use DROP ... CASCADE to drop the dependent objects too.",
777            }
778        )));
779    }
780    Ok(())
781}
782
783/// List all objects that are using the given one.
784pub async fn get_referring_objects<C>(object_id: ObjectId, db: &C) -> MetaResult<Vec<PartialObject>>
785where
786    C: ConnectionTrait,
787{
788    let objs = ObjectDependency::find()
789        .filter(object_dependency::Column::Oid.eq(object_id))
790        .join(
791            JoinType::InnerJoin,
792            object_dependency::Relation::Object1.def(),
793        )
794        .into_partial_model()
795        .all(db)
796        .await?;
797
798    Ok(objs)
799}
800
801/// `ensure_schema_empty` ensures that the schema is empty, used by `DROP SCHEMA`.
802pub async fn ensure_schema_empty<C>(schema_id: SchemaId, db: &C) -> MetaResult<()>
803where
804    C: ConnectionTrait,
805{
806    let count = Object::find()
807        .filter(object::Column::SchemaId.eq(Some(schema_id)))
808        .count(db)
809        .await?;
810    if count != 0 {
811        return Err(MetaError::permission_denied("schema is not empty"));
812    }
813
814    Ok(())
815}
816
817/// `list_user_info_by_ids` lists all users' info by their ids.
818pub async fn list_user_info_by_ids<C>(
819    user_ids: impl IntoIterator<Item = UserId>,
820    db: &C,
821) -> MetaResult<Vec<PbUserInfo>>
822where
823    C: ConnectionTrait,
824{
825    let mut user_infos = vec![];
826    for user_id in user_ids {
827        let user = User::find_by_id(user_id)
828            .one(db)
829            .await?
830            .ok_or_else(|| MetaError::catalog_id_not_found("user", user_id))?;
831        let mut user_info: PbUserInfo = user.into();
832        user_info.grant_privileges = get_user_privilege(user_id, db).await?;
833        user_infos.push(user_info);
834    }
835    Ok(user_infos)
836}
837
838/// `get_object_owner` returns the owner of the given object.
839pub async fn get_object_owner<C>(object_id: ObjectId, db: &C) -> MetaResult<UserId>
840where
841    C: ConnectionTrait,
842{
843    let obj_owner: UserId = Object::find_by_id(object_id)
844        .select_only()
845        .column(object::Column::OwnerId)
846        .into_tuple()
847        .one(db)
848        .await?
849        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
850    Ok(obj_owner)
851}
852
853/// `construct_privilege_dependency_query` constructs a query to find all privileges that are dependent on the given one.
854///
855/// # Examples
856///
857/// ```
858/// use risingwave_meta::controller::utils::construct_privilege_dependency_query;
859/// use sea_orm::sea_query::*;
860/// use sea_orm::*;
861///
862/// let query = construct_privilege_dependency_query(vec![1, 2, 3]);
863///
864/// assert_eq!(
865///    query.to_string(MysqlQueryBuilder),
866///   r#"WITH RECURSIVE `granted_privilege_ids` (`id`, `user_id`) AS (SELECT `id`, `user_id` FROM `user_privilege` WHERE `user_privilege`.`id` IN (1, 2, 3) UNION ALL (SELECT `user_privilege`.`id`, `user_privilege`.`user_id` FROM `user_privilege` INNER JOIN `granted_privilege_ids` ON `granted_privilege_ids`.`id` = `dependent_id`)) SELECT `id`, `user_id` FROM `granted_privilege_ids`"#
867/// );
868/// assert_eq!(
869///   query.to_string(PostgresQueryBuilder),
870///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL (SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id")) SELECT "id", "user_id" FROM "granted_privilege_ids""#
871/// );
872/// assert_eq!(
873///  query.to_string(SqliteQueryBuilder),
874///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id") SELECT "id", "user_id" FROM "granted_privilege_ids""#
875/// );
876/// ```
877pub fn construct_privilege_dependency_query(ids: Vec<PrivilegeId>) -> WithQuery {
878    let cte_alias = Alias::new("granted_privilege_ids");
879    let cte_return_privilege_alias = Alias::new("id");
880    let cte_return_user_alias = Alias::new("user_id");
881
882    let mut base_query = SelectStatement::new()
883        .columns([user_privilege::Column::Id, user_privilege::Column::UserId])
884        .from(UserPrivilege)
885        .and_where(user_privilege::Column::Id.is_in(ids))
886        .to_owned();
887
888    let cte_referencing = Query::select()
889        .columns([
890            (UserPrivilege, user_privilege::Column::Id),
891            (UserPrivilege, user_privilege::Column::UserId),
892        ])
893        .from(UserPrivilege)
894        .inner_join(
895            cte_alias.clone(),
896            Expr::col((cte_alias.clone(), cte_return_privilege_alias.clone()))
897                .equals(user_privilege::Column::DependentId),
898        )
899        .to_owned();
900
901    let common_table_expr = CommonTableExpression::new()
902        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
903        .columns([
904            cte_return_privilege_alias.clone(),
905            cte_return_user_alias.clone(),
906        ])
907        .table_name(cte_alias.clone())
908        .to_owned();
909
910    SelectStatement::new()
911        .columns([cte_return_privilege_alias, cte_return_user_alias])
912        .from(cte_alias.clone())
913        .to_owned()
914        .with(
915            WithClause::new()
916                .recursive(true)
917                .cte(common_table_expr)
918                .to_owned(),
919        )
920        .to_owned()
921}
922
923pub async fn get_internal_tables_by_id<C>(job_id: ObjectId, db: &C) -> MetaResult<Vec<TableId>>
924where
925    C: ConnectionTrait,
926{
927    let table_ids: Vec<TableId> = Table::find()
928        .select_only()
929        .column(table::Column::TableId)
930        .filter(
931            table::Column::TableType
932                .eq(TableType::Internal)
933                .and(table::Column::BelongsToJobId.eq(job_id)),
934        )
935        .into_tuple()
936        .all(db)
937        .await?;
938    Ok(table_ids)
939}
940
941pub async fn get_index_state_tables_by_table_id<C>(
942    table_id: TableId,
943    db: &C,
944) -> MetaResult<Vec<TableId>>
945where
946    C: ConnectionTrait,
947{
948    let mut index_table_ids: Vec<TableId> = Index::find()
949        .select_only()
950        .column(index::Column::IndexTableId)
951        .filter(index::Column::PrimaryTableId.eq(table_id))
952        .into_tuple()
953        .all(db)
954        .await?;
955
956    if !index_table_ids.is_empty() {
957        let internal_table_ids: Vec<TableId> = Table::find()
958            .select_only()
959            .column(table::Column::TableId)
960            .filter(
961                table::Column::TableType
962                    .eq(TableType::Internal)
963                    .and(table::Column::BelongsToJobId.is_in(index_table_ids.clone())),
964            )
965            .into_tuple()
966            .all(db)
967            .await?;
968
969        index_table_ids.extend(internal_table_ids.into_iter());
970    }
971
972    Ok(index_table_ids)
973}
974
/// Projection of a `user_privilege` row onto its id and the user it is granted to.
///
/// Used as the row type returned by [`get_referring_privileges_cascade`].
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "UserPrivilege")]
pub struct PartialUserPrivilege {
    /// Id of the `user_privilege` row.
    pub id: PrivilegeId,
    /// User holding the privilege (the grantee).
    pub user_id: UserId,
}
981
982pub async fn get_referring_privileges_cascade<C>(
983    ids: Vec<PrivilegeId>,
984    db: &C,
985) -> MetaResult<Vec<PartialUserPrivilege>>
986where
987    C: ConnectionTrait,
988{
989    let query = construct_privilege_dependency_query(ids);
990    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
991    let privileges = PartialUserPrivilege::find_by_statement(Statement::from_sql_and_values(
992        db.get_database_backend(),
993        sql,
994        values,
995    ))
996    .all(db)
997    .await?;
998
999    Ok(privileges)
1000}
1001
1002/// `ensure_privileges_not_referred` ensures that the privileges are not granted to any other users.
1003pub async fn ensure_privileges_not_referred<C>(ids: Vec<PrivilegeId>, db: &C) -> MetaResult<()>
1004where
1005    C: ConnectionTrait,
1006{
1007    let count = UserPrivilege::find()
1008        .filter(user_privilege::Column::DependentId.is_in(ids))
1009        .count(db)
1010        .await?;
1011    if count != 0 {
1012        return Err(MetaError::permission_denied(format!(
1013            "privileges granted to {} other ones.",
1014            count
1015        )));
1016    }
1017    Ok(())
1018}
1019
1020/// `get_user_privilege` returns the privileges of the given user.
1021pub async fn get_user_privilege<C>(user_id: UserId, db: &C) -> MetaResult<Vec<PbGrantPrivilege>>
1022where
1023    C: ConnectionTrait,
1024{
1025    let user_privileges = UserPrivilege::find()
1026        .find_also_related(Object)
1027        .filter(user_privilege::Column::UserId.eq(user_id))
1028        .all(db)
1029        .await?;
1030    Ok(user_privileges
1031        .into_iter()
1032        .map(|(privilege, object)| {
1033            let object = object.unwrap();
1034            let oid = object.oid as _;
1035            let obj = match object.obj_type {
1036                ObjectType::Database => PbGrantObject::DatabaseId(oid),
1037                ObjectType::Schema => PbGrantObject::SchemaId(oid),
1038                ObjectType::Table | ObjectType::Index => PbGrantObject::TableId(oid),
1039                ObjectType::Source => PbGrantObject::SourceId(oid),
1040                ObjectType::Sink => PbGrantObject::SinkId(oid),
1041                ObjectType::View => PbGrantObject::ViewId(oid),
1042                ObjectType::Function => PbGrantObject::FunctionId(oid),
1043                ObjectType::Connection => PbGrantObject::ConnectionId(oid),
1044                ObjectType::Subscription => PbGrantObject::SubscriptionId(oid),
1045                ObjectType::Secret => PbGrantObject::SecretId(oid),
1046            };
1047            PbGrantPrivilege {
1048                action_with_opts: vec![PbActionWithGrantOption {
1049                    action: PbAction::from(privilege.action) as _,
1050                    with_grant_option: privilege.with_grant_option,
1051                    granted_by: privilege.granted_by as _,
1052                }],
1053                object: Some(obj),
1054            }
1055        })
1056        .collect())
1057}
1058
1059/// `grant_default_privileges_automatically` grants default privileges automatically
1060/// for the given new object. It returns the list of user infos whose privileges are updated.
1061pub async fn grant_default_privileges_automatically<C>(
1062    db: &C,
1063    object_id: ObjectId,
1064) -> MetaResult<Vec<PbUserInfo>>
1065where
1066    C: ConnectionTrait,
1067{
1068    let object = Object::find_by_id(object_id)
1069        .one(db)
1070        .await?
1071        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
1072    assert_ne!(object.obj_type, ObjectType::Database);
1073
1074    let for_mview_filter = if object.obj_type == ObjectType::Table {
1075        let table_type = Table::find_by_id(object_id)
1076            .select_only()
1077            .column(table::Column::TableType)
1078            .into_tuple::<TableType>()
1079            .one(db)
1080            .await?
1081            .ok_or_else(|| MetaError::catalog_id_not_found("table", object_id))?;
1082        user_default_privilege::Column::ForMaterializedView
1083            .eq(table_type == TableType::MaterializedView)
1084    } else {
1085        user_default_privilege::Column::ForMaterializedView.eq(false)
1086    };
1087    let schema_filter = if let Some(schema_id) = &object.schema_id {
1088        user_default_privilege::Column::SchemaId.eq(*schema_id)
1089    } else {
1090        user_default_privilege::Column::SchemaId.is_null()
1091    };
1092
1093    let default_privileges: Vec<(UserId, UserId, Action, bool)> = UserDefaultPrivilege::find()
1094        .select_only()
1095        .columns([
1096            user_default_privilege::Column::Grantee,
1097            user_default_privilege::Column::GrantedBy,
1098            user_default_privilege::Column::Action,
1099            user_default_privilege::Column::WithGrantOption,
1100        ])
1101        .filter(
1102            user_default_privilege::Column::DatabaseId
1103                .eq(object.database_id.unwrap())
1104                .and(schema_filter)
1105                .and(user_default_privilege::Column::UserId.eq(object.owner_id))
1106                .and(user_default_privilege::Column::ObjectType.eq(object.obj_type))
1107                .and(for_mview_filter),
1108        )
1109        .into_tuple()
1110        .all(db)
1111        .await?;
1112    if default_privileges.is_empty() {
1113        return Ok(vec![]);
1114    }
1115
1116    let updated_user_ids = default_privileges
1117        .iter()
1118        .map(|(grantee, _, _, _)| *grantee)
1119        .collect::<HashSet<_>>();
1120
1121    let internal_table_ids = get_internal_tables_by_id(object_id, db).await?;
1122
1123    for (grantee, granted_by, action, with_grant_option) in default_privileges {
1124        UserPrivilege::insert(user_privilege::ActiveModel {
1125            user_id: Set(grantee),
1126            oid: Set(object_id),
1127            granted_by: Set(granted_by),
1128            action: Set(action),
1129            with_grant_option: Set(with_grant_option),
1130            ..Default::default()
1131        })
1132        .exec(db)
1133        .await?;
1134        if action == Action::Select && !internal_table_ids.is_empty() {
1135            // Grant SELECT privilege for internal tables if the action is SELECT.
1136            for internal_table_id in &internal_table_ids {
1137                UserPrivilege::insert(user_privilege::ActiveModel {
1138                    user_id: Set(grantee),
1139                    oid: Set(*internal_table_id as _),
1140                    granted_by: Set(granted_by),
1141                    action: Set(Action::Select),
1142                    with_grant_option: Set(with_grant_option),
1143                    ..Default::default()
1144                })
1145                .exec(db)
1146                .await?;
1147            }
1148        }
1149    }
1150
1151    let updated_user_infos = list_user_info_by_ids(updated_user_ids, db).await?;
1152    Ok(updated_user_infos)
1153}
1154
1155// todo: remove it after migrated to sql backend.
1156pub fn extract_grant_obj_id(object: &PbGrantObject) -> ObjectId {
1157    match object {
1158        PbGrantObject::DatabaseId(id)
1159        | PbGrantObject::SchemaId(id)
1160        | PbGrantObject::TableId(id)
1161        | PbGrantObject::SourceId(id)
1162        | PbGrantObject::SinkId(id)
1163        | PbGrantObject::ViewId(id)
1164        | PbGrantObject::FunctionId(id)
1165        | PbGrantObject::SubscriptionId(id)
1166        | PbGrantObject::ConnectionId(id)
1167        | PbGrantObject::SecretId(id) => *id as _,
1168    }
1169}
1170
1171pub async fn insert_fragment_relations(
1172    db: &impl ConnectionTrait,
1173    downstream_fragment_relations: &FragmentDownstreamRelation,
1174) -> MetaResult<()> {
1175    for (upstream_fragment_id, downstreams) in downstream_fragment_relations {
1176        for downstream in downstreams {
1177            let relation = fragment_relation::Model {
1178                source_fragment_id: *upstream_fragment_id as _,
1179                target_fragment_id: downstream.downstream_fragment_id as _,
1180                dispatcher_type: downstream.dispatcher_type,
1181                dist_key_indices: downstream
1182                    .dist_key_indices
1183                    .iter()
1184                    .map(|idx| *idx as i32)
1185                    .collect_vec()
1186                    .into(),
1187                output_indices: downstream
1188                    .output_mapping
1189                    .indices
1190                    .iter()
1191                    .map(|idx| *idx as i32)
1192                    .collect_vec()
1193                    .into(),
1194                output_type_mapping: Some(downstream.output_mapping.types.clone().into()),
1195            };
1196            FragmentRelation::insert(relation.into_active_model())
1197                .exec(db)
1198                .await?;
1199        }
1200    }
1201    Ok(())
1202}
1203
/// Builds the per-actor dispatchers for every relation whose upstream fragment is in
/// `fragment_ids`.
///
/// For each `fragment_relation` row with `source_fragment_id` in `fragment_ids`, the
/// `Running` actors (with their vnode bitmaps) of the source and target fragments are loaded
/// and passed to [`compose_dispatchers`]. Actor info is cached per fragment, so each fragment
/// is queried at most once per call.
///
/// Returns a map `source fragment id -> source actor id -> dispatchers`. Each dispatcher list
/// gets one entry per relation of that actor's fragment.
///
/// # Errors
///
/// Propagates database errors. Returns "failed to find fragment" when the lookup of a source
/// or target fragment, restricted to `Running` actors, yields no rows.
///
/// # Panics
///
/// Panics if such a lookup yields more than one fragment, or if [`compose_dispatchers`]
/// panics on inconsistent actor layouts.
pub async fn get_fragment_actor_dispatchers<C>(
    db: &C,
    fragment_ids: Vec<FragmentId>,
) -> MetaResult<FragmentActorDispatchers>
where
    C: ConnectionTrait,
{
    // (distribution type, running actor id -> vnode bitmap) of one fragment. The actor map
    // is behind an `Arc` so cached entries are cheap to hand out.
    type FragmentActorInfo = (
        DistributionType,
        Arc<HashMap<crate::model::ActorId, Option<Bitmap>>>,
    );
    // Memoizes `get_fragment_actors` for every fragment seen as a source or a target.
    let mut fragment_actor_cache: HashMap<FragmentId, FragmentActorInfo> = HashMap::new();
    // Loads one fragment together with its `Running` actors.
    let get_fragment_actors = |fragment_id: FragmentId| async move {
        let result: MetaResult<FragmentActorInfo> = try {
            let mut fragment_actors = Fragment::find_by_id(fragment_id)
                .find_with_related(Actor)
                .filter(actor::Column::Status.eq(ActorStatus::Running))
                .all(db)
                .await?;
            // No rows: the fragment is missing (or, presumably, has no running actors).
            if fragment_actors.is_empty() {
                return Err(anyhow!("failed to find fragment: {}", fragment_id).into());
            }
            assert_eq!(
                fragment_actors.len(),
                1,
                "find multiple fragment {:?}",
                fragment_actors
            );
            let (fragment, actors) = fragment_actors.pop().unwrap();
            (
                fragment.distribution_type,
                Arc::new(
                    actors
                        .into_iter()
                        .map(|actor| {
                            (
                                actor.actor_id as _,
                                actor
                                    .vnode_bitmap
                                    .map(|bitmap| Bitmap::from(bitmap.to_protobuf())),
                            )
                        })
                        .collect(),
                ),
            )
        };
        result
    };
    let fragment_relations = FragmentRelation::find()
        .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids))
        .all(db)
        .await?;

    let mut actor_dispatchers_map: HashMap<_, HashMap<_, Vec<_>>> = HashMap::new();
    for fragment_relation::Model {
        source_fragment_id,
        target_fragment_id,
        dispatcher_type,
        dist_key_indices,
        output_indices,
        output_type_mapping,
    } in fragment_relations
    {
        // Cache lookup (or load) of the upstream fragment's actors.
        let (source_fragment_distribution, source_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(source_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(source_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // Cache lookup (or load) of the downstream fragment's actors.
        let (target_fragment_distribution, target_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(target_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(target_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // A missing type mapping is treated as an empty one.
        let output_mapping = PbDispatchOutputMapping {
            indices: output_indices.into_u32_array(),
            types: output_type_mapping.unwrap_or_default().to_protobuf(),
        };
        let dispatchers = compose_dispatchers(
            source_fragment_distribution,
            &source_fragment_actors,
            target_fragment_id as _,
            target_fragment_distribution,
            &target_fragment_actors,
            dispatcher_type,
            dist_key_indices.into_u32_array(),
            output_mapping,
        );
        // Append this relation's dispatcher to every source actor's dispatcher list.
        let actor_dispatchers_map = actor_dispatchers_map
            .entry(source_fragment_id as _)
            .or_default();
        for (actor_id, dispatchers) in dispatchers {
            actor_dispatchers_map
                .entry(actor_id as _)
                .or_default()
                .push(dispatchers);
        }
    }
    Ok(actor_dispatchers_map)
}
1315
1316pub fn compose_dispatchers(
1317    source_fragment_distribution: DistributionType,
1318    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1319    target_fragment_id: crate::model::FragmentId,
1320    target_fragment_distribution: DistributionType,
1321    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1322    dispatcher_type: DispatcherType,
1323    dist_key_indices: Vec<u32>,
1324    output_mapping: PbDispatchOutputMapping,
1325) -> HashMap<crate::model::ActorId, PbDispatcher> {
1326    match dispatcher_type {
1327        DispatcherType::Hash => {
1328            let dispatcher = PbDispatcher {
1329                r#type: PbDispatcherType::from(dispatcher_type) as _,
1330                dist_key_indices: dist_key_indices.clone(),
1331                output_mapping: output_mapping.into(),
1332                hash_mapping: Some(
1333                    ActorMapping::from_bitmaps(
1334                        &target_fragment_actors
1335                            .iter()
1336                            .map(|(actor_id, bitmap)| {
1337                                (
1338                                    *actor_id as _,
1339                                    bitmap
1340                                        .clone()
1341                                        .expect("downstream hash dispatch must have distribution"),
1342                                )
1343                            })
1344                            .collect(),
1345                    )
1346                    .to_protobuf(),
1347                ),
1348                dispatcher_id: target_fragment_id as _,
1349                downstream_actor_id: target_fragment_actors
1350                    .keys()
1351                    .map(|actor_id| *actor_id as _)
1352                    .collect(),
1353            };
1354            source_fragment_actors
1355                .keys()
1356                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1357                .collect()
1358        }
1359        DispatcherType::Broadcast | DispatcherType::Simple => {
1360            let dispatcher = PbDispatcher {
1361                r#type: PbDispatcherType::from(dispatcher_type) as _,
1362                dist_key_indices: dist_key_indices.clone(),
1363                output_mapping: output_mapping.into(),
1364                hash_mapping: None,
1365                dispatcher_id: target_fragment_id as _,
1366                downstream_actor_id: target_fragment_actors
1367                    .keys()
1368                    .map(|actor_id| *actor_id as _)
1369                    .collect(),
1370            };
1371            source_fragment_actors
1372                .keys()
1373                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1374                .collect()
1375        }
1376        DispatcherType::NoShuffle => resolve_no_shuffle_actor_dispatcher(
1377            source_fragment_distribution,
1378            source_fragment_actors,
1379            target_fragment_distribution,
1380            target_fragment_actors,
1381        )
1382        .into_iter()
1383        .map(|(upstream_actor_id, downstream_actor_id)| {
1384            (
1385                upstream_actor_id,
1386                PbDispatcher {
1387                    r#type: PbDispatcherType::NoShuffle as _,
1388                    dist_key_indices: dist_key_indices.clone(),
1389                    output_mapping: output_mapping.clone().into(),
1390                    hash_mapping: None,
1391                    dispatcher_id: target_fragment_id as _,
1392                    downstream_actor_id: vec![downstream_actor_id as _],
1393                },
1394            )
1395        })
1396        .collect(),
1397    }
1398}
1399
1400/// return (`upstream_actor_id` -> `downstream_actor_id`)
1401pub fn resolve_no_shuffle_actor_dispatcher(
1402    source_fragment_distribution: DistributionType,
1403    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1404    target_fragment_distribution: DistributionType,
1405    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1406) -> Vec<(crate::model::ActorId, crate::model::ActorId)> {
1407    assert_eq!(source_fragment_distribution, target_fragment_distribution);
1408    assert_eq!(
1409        source_fragment_actors.len(),
1410        target_fragment_actors.len(),
1411        "no-shuffle should have equal upstream downstream actor count: {:?} {:?}",
1412        source_fragment_actors,
1413        target_fragment_actors
1414    );
1415    match source_fragment_distribution {
1416        DistributionType::Single => {
1417            let assert_singleton = |bitmap: &Option<Bitmap>| {
1418                assert!(
1419                    bitmap.as_ref().map(|bitmap| bitmap.all()).unwrap_or(true),
1420                    "not singleton: {:?}",
1421                    bitmap
1422                );
1423            };
1424            assert_eq!(
1425                source_fragment_actors.len(),
1426                1,
1427                "singleton distribution actor count not 1: {:?}",
1428                source_fragment_distribution
1429            );
1430            assert_eq!(
1431                target_fragment_actors.len(),
1432                1,
1433                "singleton distribution actor count not 1: {:?}",
1434                target_fragment_distribution
1435            );
1436            let (source_actor_id, bitmap) = source_fragment_actors.iter().next().unwrap();
1437            assert_singleton(bitmap);
1438            let (target_actor_id, bitmap) = target_fragment_actors.iter().next().unwrap();
1439            assert_singleton(bitmap);
1440            vec![(*source_actor_id, *target_actor_id)]
1441        }
1442        DistributionType::Hash => {
1443            let mut target_fragment_actor_index: HashMap<_, _> = target_fragment_actors
1444                .iter()
1445                .map(|(actor_id, bitmap)| {
1446                    let bitmap = bitmap
1447                        .as_ref()
1448                        .expect("hash distribution should have bitmap");
1449                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1450                    (first_vnode, (*actor_id, bitmap))
1451                })
1452                .collect();
1453            source_fragment_actors
1454                .iter()
1455                .map(|(source_actor_id, bitmap)| {
1456                    let bitmap = bitmap
1457                        .as_ref()
1458                        .expect("hash distribution should have bitmap");
1459                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1460                    let (target_actor_id, target_bitmap) =
1461                        target_fragment_actor_index.remove(&first_vnode).unwrap_or_else(|| {
1462                            panic!(
1463                                "cannot find matched target actor: {} {:?} {:?} {:?}",
1464                                source_actor_id,
1465                                first_vnode,
1466                                source_fragment_actors,
1467                                target_fragment_actors
1468                            );
1469                        });
1470                    assert_eq!(
1471                        bitmap,
1472                        target_bitmap,
1473                        "cannot find matched target actor due to bitmap mismatch: {} {:?} {:?} {:?}",
1474                        source_actor_id,
1475                        first_vnode,
1476                        source_fragment_actors,
1477                        target_fragment_actors
1478                    );
1479                    (*source_actor_id, target_actor_id)
1480                }).collect()
1481        }
1482    }
1483}
1484
1485/// `get_fragment_mappings` returns the fragment vnode mappings of the given job.
1486pub async fn get_fragment_mappings<C>(
1487    db: &C,
1488    job_id: ObjectId,
1489) -> MetaResult<Vec<PbFragmentWorkerSlotMapping>>
1490where
1491    C: ConnectionTrait,
1492{
1493    let job_actors: Vec<(
1494        FragmentId,
1495        DistributionType,
1496        ActorId,
1497        Option<VnodeBitmap>,
1498        WorkerId,
1499        ActorStatus,
1500    )> = Actor::find()
1501        .select_only()
1502        .columns([
1503            fragment::Column::FragmentId,
1504            fragment::Column::DistributionType,
1505        ])
1506        .columns([
1507            actor::Column::ActorId,
1508            actor::Column::VnodeBitmap,
1509            actor::Column::WorkerId,
1510            actor::Column::Status,
1511        ])
1512        .join(JoinType::InnerJoin, actor::Relation::Fragment.def())
1513        .filter(fragment::Column::JobId.eq(job_id))
1514        .into_tuple()
1515        .all(db)
1516        .await?;
1517
1518    Ok(rebuild_fragment_mapping_from_actors(job_actors))
1519}
1520
1521pub fn rebuild_fragment_mapping_from_actors(
1522    job_actors: Vec<(
1523        FragmentId,
1524        DistributionType,
1525        ActorId,
1526        Option<VnodeBitmap>,
1527        WorkerId,
1528        ActorStatus,
1529    )>,
1530) -> Vec<FragmentWorkerSlotMapping> {
1531    let mut all_actor_locations = HashMap::new();
1532    let mut actor_bitmaps = HashMap::new();
1533    let mut fragment_actors = HashMap::new();
1534    let mut fragment_dist = HashMap::new();
1535
1536    for (fragment_id, dist, actor_id, bitmap, worker_id, actor_status) in job_actors {
1537        if actor_status == ActorStatus::Inactive {
1538            continue;
1539        }
1540
1541        all_actor_locations
1542            .entry(fragment_id)
1543            .or_insert(HashMap::new())
1544            .insert(actor_id as hash::ActorId, worker_id as u32);
1545        actor_bitmaps.insert(actor_id, bitmap);
1546        fragment_actors
1547            .entry(fragment_id)
1548            .or_insert_with(Vec::new)
1549            .push(actor_id);
1550        fragment_dist.insert(fragment_id, dist);
1551    }
1552
1553    let mut result = vec![];
1554    for (fragment_id, dist) in fragment_dist {
1555        let mut actor_locations = all_actor_locations.remove(&fragment_id).unwrap();
1556        let fragment_worker_slot_mapping = match dist {
1557            DistributionType::Single => {
1558                let actor = fragment_actors
1559                    .remove(&fragment_id)
1560                    .unwrap()
1561                    .into_iter()
1562                    .exactly_one()
1563                    .unwrap() as hash::ActorId;
1564                let actor_location = actor_locations.remove(&actor).unwrap();
1565
1566                WorkerSlotMapping::new_single(WorkerSlotId::new(actor_location, 0))
1567            }
1568            DistributionType::Hash => {
1569                let actors = fragment_actors.remove(&fragment_id).unwrap();
1570
1571                let all_actor_bitmaps: HashMap<_, _> = actors
1572                    .iter()
1573                    .map(|actor_id| {
1574                        let vnode_bitmap = actor_bitmaps
1575                            .remove(actor_id)
1576                            .flatten()
1577                            .expect("actor bitmap shouldn't be none in hash fragment");
1578
1579                        let bitmap = Bitmap::from(&vnode_bitmap.to_protobuf());
1580                        (*actor_id as hash::ActorId, bitmap)
1581                    })
1582                    .collect();
1583
1584                let actor_mapping = ActorMapping::from_bitmaps(&all_actor_bitmaps);
1585
1586                actor_mapping.to_worker_slot(&actor_locations)
1587            }
1588        };
1589
1590        result.push(PbFragmentWorkerSlotMapping {
1591            fragment_id: fragment_id as u32,
1592            mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1593        })
1594    }
1595    result
1596}
1597
1598/// `get_fragment_actor_ids` returns the fragment actor ids of the given fragments.
1599pub async fn get_fragment_actor_ids<C>(
1600    db: &C,
1601    fragment_ids: Vec<FragmentId>,
1602) -> MetaResult<HashMap<FragmentId, Vec<ActorId>>>
1603where
1604    C: ConnectionTrait,
1605{
1606    let fragment_actors: Vec<(FragmentId, ActorId)> = Actor::find()
1607        .select_only()
1608        .columns([actor::Column::FragmentId, actor::Column::ActorId])
1609        .filter(actor::Column::FragmentId.is_in(fragment_ids))
1610        .into_tuple()
1611        .all(db)
1612        .await?;
1613
1614    Ok(fragment_actors.into_iter().into_group_map())
1615}
1616
1617/// For the given streaming jobs, returns
1618/// - All source fragments
1619/// - All actors
1620/// - All fragments
1621pub async fn get_fragments_for_jobs<C>(
1622    db: &C,
1623    streaming_jobs: Vec<ObjectId>,
1624) -> MetaResult<(
1625    HashMap<SourceId, BTreeSet<FragmentId>>,
1626    HashSet<ActorId>,
1627    HashSet<FragmentId>,
1628)>
1629where
1630    C: ConnectionTrait,
1631{
1632    if streaming_jobs.is_empty() {
1633        return Ok((HashMap::default(), HashSet::default(), HashSet::default()));
1634    }
1635
1636    let fragments: Vec<(FragmentId, i32, StreamNode)> = Fragment::find()
1637        .select_only()
1638        .columns([
1639            fragment::Column::FragmentId,
1640            fragment::Column::FragmentTypeMask,
1641            fragment::Column::StreamNode,
1642        ])
1643        .filter(fragment::Column::JobId.is_in(streaming_jobs))
1644        .into_tuple()
1645        .all(db)
1646        .await?;
1647    let actors: Vec<ActorId> = Actor::find()
1648        .select_only()
1649        .column(actor::Column::ActorId)
1650        .filter(
1651            actor::Column::FragmentId.is_in(fragments.iter().map(|(id, _, _)| *id).collect_vec()),
1652        )
1653        .into_tuple()
1654        .all(db)
1655        .await?;
1656
1657    let fragment_ids = fragments
1658        .iter()
1659        .map(|(fragment_id, _, _)| *fragment_id)
1660        .collect();
1661
1662    let mut source_fragment_ids: HashMap<SourceId, BTreeSet<FragmentId>> = HashMap::new();
1663    for (fragment_id, mask, stream_node) in fragments {
1664        if !FragmentTypeMask::from(mask).contains(FragmentTypeFlag::Source) {
1665            continue;
1666        }
1667        if let Some(source_id) = stream_node.to_protobuf().find_stream_source() {
1668            source_fragment_ids
1669                .entry(source_id as _)
1670                .or_default()
1671                .insert(fragment_id);
1672        }
1673    }
1674
1675    Ok((
1676        source_fragment_ids,
1677        actors.into_iter().collect(),
1678        fragment_ids,
1679    ))
1680}
1681
1682/// Build a object group for notifying the deletion of the given objects.
1683///
1684/// Note that only id fields are filled in the object info, as the arguments are partial objects.
1685/// As a result, the returned notification info should only be used for deletion.
1686pub(crate) fn build_object_group_for_delete(
1687    partial_objects: Vec<PartialObject>,
1688) -> NotificationInfo {
1689    let mut objects = vec![];
1690    for obj in partial_objects {
1691        match obj.obj_type {
1692            ObjectType::Database => objects.push(PbObject {
1693                object_info: Some(PbObjectInfo::Database(PbDatabase {
1694                    id: obj.oid as _,
1695                    ..Default::default()
1696                })),
1697            }),
1698            ObjectType::Schema => objects.push(PbObject {
1699                object_info: Some(PbObjectInfo::Schema(PbSchema {
1700                    id: obj.oid as _,
1701                    database_id: obj.database_id.unwrap() as _,
1702                    ..Default::default()
1703                })),
1704            }),
1705            ObjectType::Table => objects.push(PbObject {
1706                object_info: Some(PbObjectInfo::Table(PbTable {
1707                    id: obj.oid as _,
1708                    schema_id: obj.schema_id.unwrap() as _,
1709                    database_id: obj.database_id.unwrap() as _,
1710                    ..Default::default()
1711                })),
1712            }),
1713            ObjectType::Source => objects.push(PbObject {
1714                object_info: Some(PbObjectInfo::Source(PbSource {
1715                    id: obj.oid as _,
1716                    schema_id: obj.schema_id.unwrap() as _,
1717                    database_id: obj.database_id.unwrap() as _,
1718                    ..Default::default()
1719                })),
1720            }),
1721            ObjectType::Sink => objects.push(PbObject {
1722                object_info: Some(PbObjectInfo::Sink(PbSink {
1723                    id: obj.oid as _,
1724                    schema_id: obj.schema_id.unwrap() as _,
1725                    database_id: obj.database_id.unwrap() as _,
1726                    ..Default::default()
1727                })),
1728            }),
1729            ObjectType::Subscription => objects.push(PbObject {
1730                object_info: Some(PbObjectInfo::Subscription(PbSubscription {
1731                    id: obj.oid as _,
1732                    schema_id: obj.schema_id.unwrap() as _,
1733                    database_id: obj.database_id.unwrap() as _,
1734                    ..Default::default()
1735                })),
1736            }),
1737            ObjectType::View => objects.push(PbObject {
1738                object_info: Some(PbObjectInfo::View(PbView {
1739                    id: obj.oid as _,
1740                    schema_id: obj.schema_id.unwrap() as _,
1741                    database_id: obj.database_id.unwrap() as _,
1742                    ..Default::default()
1743                })),
1744            }),
1745            ObjectType::Index => {
1746                objects.push(PbObject {
1747                    object_info: Some(PbObjectInfo::Index(PbIndex {
1748                        id: obj.oid as _,
1749                        schema_id: obj.schema_id.unwrap() as _,
1750                        database_id: obj.database_id.unwrap() as _,
1751                        ..Default::default()
1752                    })),
1753                });
1754                objects.push(PbObject {
1755                    object_info: Some(PbObjectInfo::Table(PbTable {
1756                        id: obj.oid as _,
1757                        schema_id: obj.schema_id.unwrap() as _,
1758                        database_id: obj.database_id.unwrap() as _,
1759                        ..Default::default()
1760                    })),
1761                });
1762            }
1763            ObjectType::Function => objects.push(PbObject {
1764                object_info: Some(PbObjectInfo::Function(PbFunction {
1765                    id: obj.oid as _,
1766                    schema_id: obj.schema_id.unwrap() as _,
1767                    database_id: obj.database_id.unwrap() as _,
1768                    ..Default::default()
1769                })),
1770            }),
1771            ObjectType::Connection => objects.push(PbObject {
1772                object_info: Some(PbObjectInfo::Connection(PbConnection {
1773                    id: obj.oid as _,
1774                    schema_id: obj.schema_id.unwrap() as _,
1775                    database_id: obj.database_id.unwrap() as _,
1776                    ..Default::default()
1777                })),
1778            }),
1779            ObjectType::Secret => objects.push(PbObject {
1780                object_info: Some(PbObjectInfo::Secret(PbSecret {
1781                    id: obj.oid as _,
1782                    schema_id: obj.schema_id.unwrap() as _,
1783                    database_id: obj.database_id.unwrap() as _,
1784                    ..Default::default()
1785                })),
1786            }),
1787        }
1788    }
1789    NotificationInfo::ObjectGroup(PbObjectGroup { objects })
1790}
1791
1792pub fn extract_external_table_name_from_definition(table_definition: &str) -> Option<String> {
1793    let [mut definition]: [_; 1] = Parser::parse_sql(table_definition)
1794        .context("unable to parse table definition")
1795        .inspect_err(|e| {
1796            tracing::error!(
1797                target: "auto_schema_change",
1798                error = %e.as_report(),
1799                "failed to parse table definition")
1800        })
1801        .unwrap()
1802        .try_into()
1803        .unwrap();
1804    if let SqlStatement::CreateTable { cdc_table_info, .. } = &mut definition {
1805        cdc_table_info
1806            .clone()
1807            .map(|cdc_table_info| cdc_table_info.external_table_name)
1808    } else {
1809        None
1810    }
1811}
1812
/// `rename_relation` renames the target relation to `object_name` and rewrites its definition,
/// writing the changes into `txn` (it does not commit `txn` itself). It returns the updated
/// catalog objects and the old name.
///
/// The returned `Vec<PbObject>` lists every updated object, in update order:
/// - `Table`: the table's associated source first (if any; looked up via
///   `optional_associated_table_id`), then the table itself.
/// - `Index`: the index table first, then the index.
/// - `Source` / `Sink` / `Subscription` / `View`: only that relation.
///
/// The returned `String` is the name before the rename. For an index, it is the old name of
/// its index table.
///
/// # Errors
/// Propagates database errors.
///
/// # Panics
/// Panics if the relation or its object row is not found, or if `object_type` is not one of
/// the relation types above.
pub async fn rename_relation(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
) -> MetaResult<(Vec<PbObject>, String)> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename;

    let mut to_update_relations = vec![];
    // rename relation.
    // Loads `$entity` with id `$object_id` together with its object row and sets its name to
    // `object_name`. For every object type other than `View`, it also rewrites the name inside
    // the stored definition; a view's definition is left unchanged here. It then persists `name`
    // and `definition` (keyed by the `$identity` column), records the updated object in
    // `to_update_relations`, and evaluates to the old name.
    macro_rules! rename_relation {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            let obj = obj.unwrap();
            let old_name = relation.name.clone();
            relation.name = object_name.into();
            if obj.obj_type != ObjectType::View {
                relation.definition = alter_relation_rename(&relation.definition, object_name);
            }
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                name: Set(object_name.into()),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(ObjectModel(relation, obj).into())),
            });
            old_name
        }};
    }
    // TODO: check whether anything needs to change for shared sources.
    let old_name = match object_type {
        ObjectType::Table => {
            // A table may have an associated source; rename it to the same name first. Its old
            // name is discarded; only the table's old name is returned.
            let associated_source_id: Option<SourceId> = Source::find()
                .select_only()
                .column(source::Column::SourceId)
                .filter(source::Column::OptionalAssociatedTableId.eq(object_id))
                .into_tuple()
                .one(txn)
                .await?;
            if let Some(source_id) = associated_source_id {
                rename_relation!(Source, source, source_id, source_id);
            }
            rename_relation!(Table, table, table_id, object_id)
        }
        ObjectType::Source => rename_relation!(Source, source, source_id, object_id),
        ObjectType::Sink => rename_relation!(Sink, sink, sink_id, object_id),
        ObjectType::Subscription => {
            rename_relation!(Subscription, subscription, subscription_id, object_id)
        }
        ObjectType::View => rename_relation!(View, view, view_id, object_id),
        ObjectType::Index => {
            // Rename the index table through the macro; then update only the index row's
            // `name` column here (an index row has no definition rewritten by this function).
            let (mut index, obj) = Index::find_by_id(object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            index.name = object_name.into();
            let index_table_id = index.index_table_id;
            let old_name = rename_relation!(Table, table, table_id, index_table_id);

            // the name of index and its associated table is the same.
            let active_model = index::ActiveModel {
                index_id: sea_orm::ActiveValue::Set(index.index_id),
                name: sea_orm::ActiveValue::Set(object_name.into()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::Index(ObjectModel(index, obj.unwrap()).into())),
            });
            old_name
        }
        _ => unreachable!("only relation name can be altered."),
    };

    Ok((to_update_relations, old_name))
}
1901
1902pub async fn get_database_resource_group<C>(txn: &C, database_id: ObjectId) -> MetaResult<String>
1903where
1904    C: ConnectionTrait,
1905{
1906    let database_resource_group: Option<String> = Database::find_by_id(database_id)
1907        .select_only()
1908        .column(database::Column::ResourceGroup)
1909        .into_tuple()
1910        .one(txn)
1911        .await?
1912        .ok_or_else(|| MetaError::catalog_id_not_found("database", database_id))?;
1913
1914    Ok(database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned()))
1915}
1916
1917pub async fn get_existing_job_resource_group<C>(
1918    txn: &C,
1919    streaming_job_id: ObjectId,
1920) -> MetaResult<String>
1921where
1922    C: ConnectionTrait,
1923{
1924    let (job_specific_resource_group, database_resource_group): (Option<String>, Option<String>) =
1925        StreamingJob::find_by_id(streaming_job_id)
1926            .select_only()
1927            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
1928            .join(JoinType::InnerJoin, object::Relation::Database2.def())
1929            .column(streaming_job::Column::SpecificResourceGroup)
1930            .column(database::Column::ResourceGroup)
1931            .into_tuple()
1932            .one(txn)
1933            .await?
1934            .ok_or_else(|| MetaError::catalog_id_not_found("streaming job", streaming_job_id))?;
1935
1936    Ok(job_specific_resource_group.unwrap_or_else(|| {
1937        database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned())
1938    }))
1939}
1940
1941pub fn filter_workers_by_resource_group(
1942    workers: &HashMap<u32, WorkerNode>,
1943    resource_group: &str,
1944) -> BTreeSet<WorkerId> {
1945    workers
1946        .iter()
1947        .filter(|&(_, worker)| {
1948            worker
1949                .resource_group()
1950                .map(|node_label| node_label.as_str() == resource_group)
1951                .unwrap_or(false)
1952        })
1953        .map(|(id, _)| (*id as WorkerId))
1954        .collect()
1955}
1956
1957/// `rename_relation_refer` updates the definition of relations that refer to the target one,
1958/// it commits the changes to the transaction and returns all the updated relations.
1959pub async fn rename_relation_refer(
1960    txn: &DatabaseTransaction,
1961    object_type: ObjectType,
1962    object_id: ObjectId,
1963    object_name: &str,
1964    old_name: &str,
1965) -> MetaResult<Vec<PbObject>> {
1966    use sea_orm::ActiveModelTrait;
1967
1968    use crate::controller::rename::alter_relation_rename_refs;
1969
1970    let mut to_update_relations = vec![];
1971    macro_rules! rename_relation_ref {
1972        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
1973            let (mut relation, obj) = $entity::find_by_id($object_id)
1974                .find_also_related(Object)
1975                .one(txn)
1976                .await?
1977                .unwrap();
1978            relation.definition =
1979                alter_relation_rename_refs(&relation.definition, old_name, object_name);
1980            let active_model = $table::ActiveModel {
1981                $identity: Set(relation.$identity),
1982                definition: Set(relation.definition.clone()),
1983                ..Default::default()
1984            };
1985            active_model.update(txn).await?;
1986            to_update_relations.push(PbObject {
1987                object_info: Some(PbObjectInfo::$entity(
1988                    ObjectModel(relation, obj.unwrap()).into(),
1989                )),
1990            });
1991        }};
1992    }
1993    let mut objs = get_referring_objects(object_id, txn).await?;
1994    if object_type == ObjectType::Table {
1995        let incoming_sinks: I32Array = Table::find_by_id(object_id)
1996            .select_only()
1997            .column(table::Column::IncomingSinks)
1998            .into_tuple()
1999            .one(txn)
2000            .await?
2001            .ok_or_else(|| MetaError::catalog_id_not_found("table", object_id))?;
2002
2003        objs.extend(
2004            incoming_sinks
2005                .into_inner()
2006                .into_iter()
2007                .map(|id| PartialObject {
2008                    oid: id,
2009                    obj_type: ObjectType::Sink,
2010                    schema_id: None,
2011                    database_id: None,
2012                }),
2013        );
2014    }
2015
2016    for obj in objs {
2017        match obj.obj_type {
2018            ObjectType::Table => rename_relation_ref!(Table, table, table_id, obj.oid),
2019            ObjectType::Sink => rename_relation_ref!(Sink, sink, sink_id, obj.oid),
2020            ObjectType::Subscription => {
2021                rename_relation_ref!(Subscription, subscription, subscription_id, obj.oid)
2022            }
2023            ObjectType::View => rename_relation_ref!(View, view, view_id, obj.oid),
2024            ObjectType::Index => {
2025                let index_table_id: Option<TableId> = Index::find_by_id(obj.oid)
2026                    .select_only()
2027                    .column(index::Column::IndexTableId)
2028                    .into_tuple()
2029                    .one(txn)
2030                    .await?;
2031                rename_relation_ref!(Table, table, table_id, index_table_id.unwrap());
2032            }
2033            _ => {
2034                bail!(
2035                    "only the table, sink, subscription, view and index will depend on other objects."
2036                )
2037            }
2038        }
2039    }
2040
2041    Ok(to_update_relations)
2042}
2043
2044/// Validate that subscription can be safely deleted, meeting any of the following conditions:
2045/// 1. The upstream table is not referred to by any cross-db mv.
2046/// 2. After deleting the subscription, the upstream table still has at least one subscription.
2047pub async fn validate_subscription_deletion<C>(txn: &C, subscription_id: ObjectId) -> MetaResult<()>
2048where
2049    C: ConnectionTrait,
2050{
2051    let upstream_table_id: ObjectId = Subscription::find_by_id(subscription_id)
2052        .select_only()
2053        .column(subscription::Column::DependentTableId)
2054        .into_tuple()
2055        .one(txn)
2056        .await?
2057        .ok_or_else(|| MetaError::catalog_id_not_found("subscription", subscription_id))?;
2058
2059    let cnt = Subscription::find()
2060        .filter(subscription::Column::DependentTableId.eq(upstream_table_id))
2061        .count(txn)
2062        .await?;
2063    if cnt > 1 {
2064        // Ensure that at least one subscription is remained for the upstream table
2065        // once the subscription is dropped.
2066        return Ok(());
2067    }
2068
2069    // Ensure that the upstream table is not referred by any cross-db mv.
2070    let obj_alias = Alias::new("o1");
2071    let used_by_alias = Alias::new("o2");
2072    let count = ObjectDependency::find()
2073        .join_as(
2074            JoinType::InnerJoin,
2075            object_dependency::Relation::Object2.def(),
2076            obj_alias.clone(),
2077        )
2078        .join_as(
2079            JoinType::InnerJoin,
2080            object_dependency::Relation::Object1.def(),
2081            used_by_alias.clone(),
2082        )
2083        .filter(
2084            object_dependency::Column::Oid
2085                .eq(upstream_table_id)
2086                .and(object_dependency::Column::UsedBy.ne(subscription_id))
2087                .and(
2088                    Expr::col((obj_alias, object::Column::DatabaseId))
2089                        .ne(Expr::col((used_by_alias, object::Column::DatabaseId))),
2090                ),
2091        )
2092        .count(txn)
2093        .await?;
2094
2095    if count != 0 {
2096        return Err(MetaError::permission_denied(format!(
2097            "Referenced by {} cross-db objects.",
2098            count
2099        )));
2100    }
2101
2102    Ok(())
2103}
2104
#[cfg(test)]
mod tests {
    use super::*;

    /// CDC table definitions yield the quoted upstream table name.
    #[test]
    fn test_extract_cdc_table_name() {
        let ddl1 = "CREATE TABLE t1 () FROM pg_source TABLE 'public.t1'";
        let ddl2 = "CREATE TABLE t2 (v1 int) FROM pg_source TABLE 'mydb.t2'";
        assert_eq!(
            extract_external_table_name_from_definition(ddl1),
            Some("public.t1".into())
        );
        assert_eq!(
            extract_external_table_name_from_definition(ddl2),
            Some("mydb.t2".into())
        );
    }

    /// Definitions without a CDC table clause yield `None`.
    #[test]
    fn test_extract_cdc_table_name_from_non_cdc_definition() {
        // A regular table has no upstream CDC table.
        assert_eq!(
            extract_external_table_name_from_definition("CREATE TABLE t3 (v1 int)"),
            None
        );
        // A statement other than CREATE TABLE never carries a CDC table name.
        assert_eq!(extract_external_table_name_from_definition("SELECT 1"), None);
    }
}