risingwave_meta/controller/
utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::{BTreeSet, HashMap, HashSet};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
23use risingwave_common::hash::{ActorMapping, VnodeBitmapExt, WorkerSlotId, WorkerSlotMapping};
24use risingwave_common::types::Datum;
25use risingwave_common::util::value_encoding::DatumToProtoExt;
26use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
27use risingwave_common::{bail, hash};
28use risingwave_meta_model::actor::ActorStatus;
29use risingwave_meta_model::fragment::DistributionType;
30use risingwave_meta_model::object::ObjectType;
31use risingwave_meta_model::prelude::*;
32use risingwave_meta_model::table::TableType;
33use risingwave_meta_model::user_privilege::Action;
34use risingwave_meta_model::{
35    ActorId, ColumnCatalogArray, DataTypeArray, DatabaseId, DispatcherType, FragmentId, I32Array,
36    JobStatus, ObjectId, PrivilegeId, SchemaId, SinkId, SourceId, StreamNode, StreamSourceInfo,
37    TableId, UserId, VnodeBitmap, WorkerId, actor, connection, database, fragment,
38    fragment_relation, function, index, object, object_dependency, schema, secret, sink, source,
39    streaming_job, subscription, table, user, user_default_privilege, user_privilege, view,
40};
41use risingwave_meta_model_migration::WithQuery;
42use risingwave_pb::catalog::{
43    PbConnection, PbDatabase, PbFunction, PbIndex, PbSchema, PbSecret, PbSink, PbSource,
44    PbSubscription, PbTable, PbView,
45};
46use risingwave_pb::common::WorkerNode;
47use risingwave_pb::expr::{PbExprNode, expr_node};
48use risingwave_pb::meta::object::PbObjectInfo;
49use risingwave_pb::meta::subscribe_response::Info as NotificationInfo;
50use risingwave_pb::meta::{
51    FragmentWorkerSlotMapping, PbFragmentWorkerSlotMapping, PbObject, PbObjectGroup,
52};
53use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
54use risingwave_pb::plan_common::{ColumnCatalog, DefaultColumnDesc};
55use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher, PbDispatcherType};
56use risingwave_pb::user::grant_privilege::{PbActionWithGrantOption, PbObject as PbGrantObject};
57use risingwave_pb::user::{PbAction, PbGrantPrivilege, PbUserInfo};
58use risingwave_sqlparser::ast::Statement as SqlStatement;
59use risingwave_sqlparser::parser::Parser;
60use sea_orm::sea_query::{
61    Alias, CommonTableExpression, Expr, Query, QueryStatementBuilder, SelectStatement, UnionType,
62    WithClause,
63};
64use sea_orm::{
65    ColumnTrait, ConnectionTrait, DatabaseTransaction, DerivePartialModel, EntityTrait,
66    FromQueryResult, IntoActiveModel, JoinType, Order, PaginatorTrait, QueryFilter, QuerySelect,
67    RelationTrait, Set, Statement,
68};
69use thiserror_ext::AsReport;
70
71use crate::barrier::SharedFragmentInfo;
72use crate::controller::ObjectModel;
73use crate::controller::fragment::FragmentTypeMaskExt;
74use crate::model::{FragmentActorDispatchers, FragmentDownstreamRelation};
75use crate::{MetaError, MetaResult};
76
77/// This function will construct a query using recursive cte to find all objects[(id, `obj_type`)] that are used by the given object.
78///
79/// # Examples
80///
81/// ```
82/// use risingwave_meta::controller::utils::construct_obj_dependency_query;
83/// use sea_orm::sea_query::*;
84/// use sea_orm::*;
85///
86/// let query = construct_obj_dependency_query(1);
87///
88/// assert_eq!(
89///     query.to_string(MysqlQueryBuilder),
90///     r#"WITH RECURSIVE `used_by_object_ids` (`used_by`) AS (SELECT `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `oid` FROM `object` WHERE `object`.`database_id` = 1 OR `object`.`schema_id` = 1) UNION ALL (SELECT `object_dependency`.`used_by` FROM `object_dependency` INNER JOIN `used_by_object_ids` ON `used_by_object_ids`.`used_by` = `oid`)) SELECT DISTINCT `oid`, `obj_type`, `schema_id`, `database_id` FROM `used_by_object_ids` INNER JOIN `object` ON `used_by_object_ids`.`used_by` = `oid` ORDER BY `oid` DESC"#
91/// );
92/// assert_eq!(
93///     query.to_string(PostgresQueryBuilder),
94///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1) UNION ALL (SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid")) SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
95/// );
96/// assert_eq!(
97///     query.to_string(SqliteQueryBuilder),
98///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1 UNION ALL SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid") SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
99/// );
100/// ```
101pub fn construct_obj_dependency_query(obj_id: ObjectId) -> WithQuery {
102    let cte_alias = Alias::new("used_by_object_ids");
103    let cte_return_alias = Alias::new("used_by");
104
105    let mut base_query = SelectStatement::new()
106        .column(object_dependency::Column::UsedBy)
107        .from(ObjectDependency)
108        .and_where(object_dependency::Column::Oid.eq(obj_id))
109        .to_owned();
110
111    let belonged_obj_query = SelectStatement::new()
112        .column(object::Column::Oid)
113        .from(Object)
114        .and_where(
115            object::Column::DatabaseId
116                .eq(obj_id)
117                .or(object::Column::SchemaId.eq(obj_id)),
118        )
119        .to_owned();
120
121    let cte_referencing = Query::select()
122        .column((ObjectDependency, object_dependency::Column::UsedBy))
123        .from(ObjectDependency)
124        .inner_join(
125            cte_alias.clone(),
126            Expr::col((cte_alias.clone(), cte_return_alias.clone()))
127                .equals(object_dependency::Column::Oid),
128        )
129        .to_owned();
130
131    let common_table_expr = CommonTableExpression::new()
132        .query(
133            base_query
134                .union(UnionType::All, belonged_obj_query)
135                .union(UnionType::All, cte_referencing)
136                .to_owned(),
137        )
138        .column(cte_return_alias.clone())
139        .table_name(cte_alias.clone())
140        .to_owned();
141
142    SelectStatement::new()
143        .distinct()
144        .columns([
145            object::Column::Oid,
146            object::Column::ObjType,
147            object::Column::SchemaId,
148            object::Column::DatabaseId,
149        ])
150        .from(cte_alias.clone())
151        .inner_join(
152            Object,
153            Expr::col((cte_alias, cte_return_alias.clone())).equals(object::Column::Oid),
154        )
155        .order_by(object::Column::Oid, Order::Desc)
156        .to_owned()
157        .with(
158            WithClause::new()
159                .recursive(true)
160                .cte(common_table_expr)
161                .to_owned(),
162        )
163        .to_owned()
164}
165
166/// This function will construct a query using recursive cte to find if dependent objects are already relying on the target table.
167///
168/// # Examples
169///
170/// ```
171/// use risingwave_meta::controller::utils::construct_sink_cycle_check_query;
172/// use sea_orm::sea_query::*;
173/// use sea_orm::*;
174///
175/// let query = construct_sink_cycle_check_query(1, vec![2, 3]);
176///
177/// assert_eq!(
178///     query.to_string(MysqlQueryBuilder),
179///     r#"WITH RECURSIVE `used_by_object_ids_with_sink` (`oid`, `used_by`) AS (SELECT `oid`, `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `obj_dependency_with_sink`.`oid`, `obj_dependency_with_sink`.`used_by` FROM (SELECT `oid`, `used_by` FROM `object_dependency` UNION ALL (SELECT `sink_id`, `target_table` FROM `sink` WHERE `sink`.`target_table` IS NOT NULL)) AS `obj_dependency_with_sink` INNER JOIN `used_by_object_ids_with_sink` ON `used_by_object_ids_with_sink`.`used_by` = `obj_dependency_with_sink`.`oid` WHERE `used_by_object_ids_with_sink`.`used_by` <> `used_by_object_ids_with_sink`.`oid`)) SELECT COUNT(`used_by_object_ids_with_sink`.`used_by`) FROM `used_by_object_ids_with_sink` WHERE `used_by_object_ids_with_sink`.`used_by` IN (2, 3)"#
180/// );
181/// assert_eq!(
182///     query.to_string(PostgresQueryBuilder),
183///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL (SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL)) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid")) SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
184/// );
185/// assert_eq!(
186///     query.to_string(SqliteQueryBuilder),
187///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid") SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
188/// );
189/// ```
190pub fn construct_sink_cycle_check_query(
191    target_table: ObjectId,
192    dependent_objects: Vec<ObjectId>,
193) -> WithQuery {
194    let cte_alias = Alias::new("used_by_object_ids_with_sink");
195    let depend_alias = Alias::new("obj_dependency_with_sink");
196
197    let mut base_query = SelectStatement::new()
198        .columns([
199            object_dependency::Column::Oid,
200            object_dependency::Column::UsedBy,
201        ])
202        .from(ObjectDependency)
203        .and_where(object_dependency::Column::Oid.eq(target_table))
204        .to_owned();
205
206    let query_sink_deps = SelectStatement::new()
207        .columns([sink::Column::SinkId, sink::Column::TargetTable])
208        .from(Sink)
209        .and_where(sink::Column::TargetTable.is_not_null())
210        .to_owned();
211
212    let cte_referencing = Query::select()
213        .column((depend_alias.clone(), object_dependency::Column::Oid))
214        .column((depend_alias.clone(), object_dependency::Column::UsedBy))
215        .from_subquery(
216            SelectStatement::new()
217                .columns([
218                    object_dependency::Column::Oid,
219                    object_dependency::Column::UsedBy,
220                ])
221                .from(ObjectDependency)
222                .union(UnionType::All, query_sink_deps)
223                .to_owned(),
224            depend_alias.clone(),
225        )
226        .inner_join(
227            cte_alias.clone(),
228            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).eq(Expr::col((
229                depend_alias.clone(),
230                object_dependency::Column::Oid,
231            ))),
232        )
233        .and_where(
234            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).ne(Expr::col((
235                cte_alias.clone(),
236                object_dependency::Column::Oid,
237            ))),
238        )
239        .to_owned();
240
241    let common_table_expr = CommonTableExpression::new()
242        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
243        .columns([
244            object_dependency::Column::Oid,
245            object_dependency::Column::UsedBy,
246        ])
247        .table_name(cte_alias.clone())
248        .to_owned();
249
250    SelectStatement::new()
251        .expr(Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).count())
252        .from(cte_alias.clone())
253        .and_where(
254            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy))
255                .is_in(dependent_objects),
256        )
257        .to_owned()
258        .with(
259            WithClause::new()
260                .recursive(true)
261                .cte(common_table_expr)
262                .to_owned(),
263        )
264        .to_owned()
265}
266
/// Partial model of the `Object` entity: an object's id, type and owning namespace.
///
/// This is the row type produced by `get_referring_objects_cascade`.
#[derive(Clone, DerivePartialModel, FromQueryResult, Debug)]
#[sea_orm(entity = "Object")]
pub struct PartialObject {
    /// Object id (`object.oid`).
    pub oid: ObjectId,
    /// Kind of the object.
    pub obj_type: ObjectType,
    /// Owning schema; `None` when the column is NULL.
    pub schema_id: Option<SchemaId>,
    /// Owning database; `None` when the column is NULL.
    pub database_id: Option<DatabaseId>,
}
275
/// Partial model of the `Fragment` entity: a fragment's id, its owning job and its state table ids.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Fragment")]
pub struct PartialFragmentStateTables {
    /// Fragment id.
    pub fragment_id: FragmentId,
    /// Object id of the job this fragment belongs to.
    pub job_id: ObjectId,
    /// Ids of the state tables referenced by this fragment.
    pub state_table_ids: I32Array,
}
283
/// Partial model of the `Actor` entity: which fragment an actor belongs to, the worker it is
/// assigned to, and its status.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Actor")]
pub struct PartialActorLocation {
    /// Actor id.
    pub actor_id: ActorId,
    /// Fragment the actor belongs to.
    pub fragment_id: FragmentId,
    /// Worker the actor is placed on.
    pub worker_id: WorkerId,
    /// Current actor status.
    pub status: ActorStatus,
}
292
/// Row type describing a fragment.
///
/// It derives only `FromQueryResult`, not `DerivePartialModel`. It is therefore populated from
/// whatever columns a hand-written query selects, matched by field name.
#[derive(FromQueryResult)]
pub struct FragmentDesc {
    /// Fragment id.
    pub fragment_id: FragmentId,
    /// Object id of the owning job.
    pub job_id: ObjectId,
    /// Raw fragment type bits; presumably a `FragmentTypeMask` value — TODO confirm.
    pub fragment_type_mask: i32,
    /// Distribution type of the fragment.
    pub distribution_type: DistributionType,
    /// Ids of the state tables used by the fragment.
    pub state_table_ids: I32Array,
    /// NOTE(review): not a `fragment` column in this file's partial models; presumably computed
    /// by the producing query (e.g. an actor count) — confirm against the caller.
    pub parallelism: i64,
    /// Vnode count of the fragment.
    pub vnode_count: i32,
    /// Serialized stream plan node of the fragment.
    pub stream_node: StreamNode,
}
304
305/// List all objects that are using the given one in a cascade way. It runs a recursive CTE to find all the dependencies.
306pub async fn get_referring_objects_cascade<C>(
307    obj_id: ObjectId,
308    db: &C,
309) -> MetaResult<Vec<PartialObject>>
310where
311    C: ConnectionTrait,
312{
313    let query = construct_obj_dependency_query(obj_id);
314    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
315    let objects = PartialObject::find_by_statement(Statement::from_sql_and_values(
316        db.get_database_backend(),
317        sql,
318        values,
319    ))
320    .all(db)
321    .await?;
322    Ok(objects)
323}
324
325/// Check if create a sink with given dependent objects into the target table will cause a cycle, return true if it will.
326pub async fn check_sink_into_table_cycle<C>(
327    target_table: ObjectId,
328    dependent_objs: Vec<ObjectId>,
329    db: &C,
330) -> MetaResult<bool>
331where
332    C: ConnectionTrait,
333{
334    if dependent_objs.is_empty() {
335        return Ok(false);
336    }
337
338    // special check for self referencing
339    if dependent_objs.contains(&target_table) {
340        return Ok(true);
341    }
342
343    let query = construct_sink_cycle_check_query(target_table, dependent_objs);
344    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
345
346    let res = db
347        .query_one(Statement::from_sql_and_values(
348            db.get_database_backend(),
349            sql,
350            values,
351        ))
352        .await?
353        .unwrap();
354
355    let cnt: i64 = res.try_get_by(0)?;
356
357    Ok(cnt != 0)
358}
359
360/// `ensure_object_id` ensures the existence of target object in the cluster.
361pub async fn ensure_object_id<C>(
362    object_type: ObjectType,
363    obj_id: ObjectId,
364    db: &C,
365) -> MetaResult<()>
366where
367    C: ConnectionTrait,
368{
369    let count = Object::find_by_id(obj_id).count(db).await?;
370    if count == 0 {
371        return Err(MetaError::catalog_id_not_found(
372            object_type.as_str(),
373            obj_id,
374        ));
375    }
376    Ok(())
377}
378
379pub async fn ensure_job_not_canceled<C>(job_id: ObjectId, db: &C) -> MetaResult<()>
380where
381    C: ConnectionTrait,
382{
383    let count = Object::find_by_id(job_id).count(db).await?;
384    if count == 0 {
385        return Err(MetaError::cancelled(format!(
386            "job {} might be cancelled manually or by recovery",
387            job_id
388        )));
389    }
390    Ok(())
391}
392
393/// `ensure_user_id` ensures the existence of target user in the cluster.
394pub async fn ensure_user_id<C>(user_id: UserId, db: &C) -> MetaResult<()>
395where
396    C: ConnectionTrait,
397{
398    let count = User::find_by_id(user_id).count(db).await?;
399    if count == 0 {
400        return Err(anyhow!("user {} was concurrently dropped", user_id).into());
401    }
402    Ok(())
403}
404
405/// `check_database_name_duplicate` checks whether the database name is already used in the cluster.
406pub async fn check_database_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
407where
408    C: ConnectionTrait,
409{
410    let count = Database::find()
411        .filter(database::Column::Name.eq(name))
412        .count(db)
413        .await?;
414    if count > 0 {
415        assert_eq!(count, 1);
416        return Err(MetaError::catalog_duplicated("database", name));
417    }
418    Ok(())
419}
420
421/// `check_function_signature_duplicate` checks whether the function name and its signature is already used in the target namespace.
422pub async fn check_function_signature_duplicate<C>(
423    pb_function: &PbFunction,
424    db: &C,
425) -> MetaResult<()>
426where
427    C: ConnectionTrait,
428{
429    let count = Function::find()
430        .inner_join(Object)
431        .filter(
432            object::Column::DatabaseId
433                .eq(pb_function.database_id as DatabaseId)
434                .and(object::Column::SchemaId.eq(pb_function.schema_id as SchemaId))
435                .and(function::Column::Name.eq(&pb_function.name))
436                .and(
437                    function::Column::ArgTypes
438                        .eq(DataTypeArray::from(pb_function.arg_types.clone())),
439                ),
440        )
441        .count(db)
442        .await?;
443    if count > 0 {
444        assert_eq!(count, 1);
445        return Err(MetaError::catalog_duplicated("function", &pb_function.name));
446    }
447    Ok(())
448}
449
450/// `check_connection_name_duplicate` checks whether the connection name is already used in the target namespace.
451pub async fn check_connection_name_duplicate<C>(
452    pb_connection: &PbConnection,
453    db: &C,
454) -> MetaResult<()>
455where
456    C: ConnectionTrait,
457{
458    let count = Connection::find()
459        .inner_join(Object)
460        .filter(
461            object::Column::DatabaseId
462                .eq(pb_connection.database_id as DatabaseId)
463                .and(object::Column::SchemaId.eq(pb_connection.schema_id as SchemaId))
464                .and(connection::Column::Name.eq(&pb_connection.name)),
465        )
466        .count(db)
467        .await?;
468    if count > 0 {
469        assert_eq!(count, 1);
470        return Err(MetaError::catalog_duplicated(
471            "connection",
472            &pb_connection.name,
473        ));
474    }
475    Ok(())
476}
477
478pub async fn check_secret_name_duplicate<C>(pb_secret: &PbSecret, db: &C) -> MetaResult<()>
479where
480    C: ConnectionTrait,
481{
482    let count = Secret::find()
483        .inner_join(Object)
484        .filter(
485            object::Column::DatabaseId
486                .eq(pb_secret.database_id as DatabaseId)
487                .and(object::Column::SchemaId.eq(pb_secret.schema_id as SchemaId))
488                .and(secret::Column::Name.eq(&pb_secret.name)),
489        )
490        .count(db)
491        .await?;
492    if count > 0 {
493        assert_eq!(count, 1);
494        return Err(MetaError::catalog_duplicated("secret", &pb_secret.name));
495    }
496    Ok(())
497}
498
499pub async fn check_subscription_name_duplicate<C>(
500    pb_subscription: &PbSubscription,
501    db: &C,
502) -> MetaResult<()>
503where
504    C: ConnectionTrait,
505{
506    let count = Subscription::find()
507        .inner_join(Object)
508        .filter(
509            object::Column::DatabaseId
510                .eq(pb_subscription.database_id as DatabaseId)
511                .and(object::Column::SchemaId.eq(pb_subscription.schema_id as SchemaId))
512                .and(subscription::Column::Name.eq(&pb_subscription.name)),
513        )
514        .count(db)
515        .await?;
516    if count > 0 {
517        assert_eq!(count, 1);
518        return Err(MetaError::catalog_duplicated(
519            "subscription",
520            &pb_subscription.name,
521        ));
522    }
523    Ok(())
524}
525
526/// `check_user_name_duplicate` checks whether the user is already existed in the cluster.
527pub async fn check_user_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
528where
529    C: ConnectionTrait,
530{
531    let count = User::find()
532        .filter(user::Column::Name.eq(name))
533        .count(db)
534        .await?;
535    if count > 0 {
536        assert_eq!(count, 1);
537        return Err(MetaError::catalog_duplicated("user", name));
538    }
539    Ok(())
540}
541
542/// `check_relation_name_duplicate` checks whether the relation name is already used in the target namespace.
543pub async fn check_relation_name_duplicate<C>(
544    name: &str,
545    database_id: DatabaseId,
546    schema_id: SchemaId,
547    db: &C,
548) -> MetaResult<()>
549where
550    C: ConnectionTrait,
551{
552    macro_rules! check_duplicated {
553        ($obj_type:expr, $entity:ident, $table:ident) => {
554            let object_id = Object::find()
555                .select_only()
556                .column(object::Column::Oid)
557                .inner_join($entity)
558                .filter(
559                    object::Column::DatabaseId
560                        .eq(Some(database_id))
561                        .and(object::Column::SchemaId.eq(Some(schema_id)))
562                        .and($table::Column::Name.eq(name)),
563                )
564                .into_tuple::<ObjectId>()
565                .one(db)
566                .await?;
567            if let Some(oid) = object_id {
568                let check_creation = if $obj_type == ObjectType::View {
569                    false
570                } else if $obj_type == ObjectType::Source {
571                    let source_info = Source::find_by_id(oid)
572                        .select_only()
573                        .column(source::Column::SourceInfo)
574                        .into_tuple::<Option<StreamSourceInfo>>()
575                        .one(db)
576                        .await?
577                        .unwrap();
578                    source_info.map_or(false, |info| info.to_protobuf().is_shared())
579                } else {
580                    true
581                };
582                return if check_creation
583                    && !matches!(
584                        StreamingJob::find_by_id(oid)
585                            .select_only()
586                            .column(streaming_job::Column::JobStatus)
587                            .into_tuple::<JobStatus>()
588                            .one(db)
589                            .await?,
590                        Some(JobStatus::Created)
591                    ) {
592                    Err(MetaError::catalog_under_creation(
593                        $obj_type.as_str(),
594                        name,
595                        oid,
596                    ))
597                } else {
598                    Err(MetaError::catalog_duplicated($obj_type.as_str(), name))
599                };
600            }
601        };
602    }
603    check_duplicated!(ObjectType::Table, Table, table);
604    check_duplicated!(ObjectType::Source, Source, source);
605    check_duplicated!(ObjectType::Sink, Sink, sink);
606    check_duplicated!(ObjectType::Index, Index, index);
607    check_duplicated!(ObjectType::View, View, view);
608
609    Ok(())
610}
611
612/// `check_schema_name_duplicate` checks whether the schema name is already used in the target database.
613pub async fn check_schema_name_duplicate<C>(
614    name: &str,
615    database_id: DatabaseId,
616    db: &C,
617) -> MetaResult<()>
618where
619    C: ConnectionTrait,
620{
621    let count = Object::find()
622        .inner_join(Schema)
623        .filter(
624            object::Column::ObjType
625                .eq(ObjectType::Schema)
626                .and(object::Column::DatabaseId.eq(Some(database_id)))
627                .and(schema::Column::Name.eq(name)),
628        )
629        .count(db)
630        .await?;
631    if count != 0 {
632        return Err(MetaError::catalog_duplicated("schema", name));
633    }
634
635    Ok(())
636}
637
638/// `check_object_refer_for_drop` checks whether the object is used by other objects except indexes.
639/// It returns an error that contains the details of the referring objects if it is used by others.
640pub async fn check_object_refer_for_drop<C>(
641    object_type: ObjectType,
642    object_id: ObjectId,
643    db: &C,
644) -> MetaResult<()>
645where
646    C: ConnectionTrait,
647{
648    // Ignore indexes.
649    let count = if object_type == ObjectType::Table {
650        ObjectDependency::find()
651            .join(
652                JoinType::InnerJoin,
653                object_dependency::Relation::Object1.def(),
654            )
655            .filter(
656                object_dependency::Column::Oid
657                    .eq(object_id)
658                    .and(object::Column::ObjType.ne(ObjectType::Index)),
659            )
660            .count(db)
661            .await?
662    } else {
663        ObjectDependency::find()
664            .filter(object_dependency::Column::Oid.eq(object_id))
665            .count(db)
666            .await?
667    };
668    if count != 0 {
669        // find the name of all objects that are using the given one.
670        let referring_objects = get_referring_objects(object_id, db).await?;
671        let referring_objs_map = referring_objects
672            .into_iter()
673            .filter(|o| o.obj_type != ObjectType::Index)
674            .into_group_map_by(|o| o.obj_type);
675        let mut details = vec![];
676        for (obj_type, objs) in referring_objs_map {
677            match obj_type {
678                ObjectType::Table => {
679                    let tables: Vec<(String, String)> = Object::find()
680                        .join(JoinType::InnerJoin, object::Relation::Table.def())
681                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
682                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
683                        .select_only()
684                        .column(schema::Column::Name)
685                        .column(table::Column::Name)
686                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
687                        .into_tuple()
688                        .all(db)
689                        .await?;
690                    details.extend(tables.into_iter().map(|(schema_name, table_name)| {
691                        format!(
692                            "materialized view {}.{} depends on it",
693                            schema_name, table_name
694                        )
695                    }));
696                }
697                ObjectType::Sink => {
698                    let sinks: Vec<(String, String)> = Object::find()
699                        .join(JoinType::InnerJoin, object::Relation::Sink.def())
700                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
701                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
702                        .select_only()
703                        .column(schema::Column::Name)
704                        .column(sink::Column::Name)
705                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
706                        .into_tuple()
707                        .all(db)
708                        .await?;
709                    details.extend(sinks.into_iter().map(|(schema_name, sink_name)| {
710                        format!("sink {}.{} depends on it", schema_name, sink_name)
711                    }));
712                }
713                ObjectType::View => {
714                    let views: Vec<(String, String)> = Object::find()
715                        .join(JoinType::InnerJoin, object::Relation::View.def())
716                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
717                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
718                        .select_only()
719                        .column(schema::Column::Name)
720                        .column(view::Column::Name)
721                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
722                        .into_tuple()
723                        .all(db)
724                        .await?;
725                    details.extend(views.into_iter().map(|(schema_name, view_name)| {
726                        format!("view {}.{} depends on it", schema_name, view_name)
727                    }));
728                }
729                ObjectType::Subscription => {
730                    let subscriptions: Vec<(String, String)> = Object::find()
731                        .join(JoinType::InnerJoin, object::Relation::Subscription.def())
732                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
733                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
734                        .select_only()
735                        .column(schema::Column::Name)
736                        .column(subscription::Column::Name)
737                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
738                        .into_tuple()
739                        .all(db)
740                        .await?;
741                    details.extend(subscriptions.into_iter().map(
742                        |(schema_name, subscription_name)| {
743                            format!(
744                                "subscription {}.{} depends on it",
745                                schema_name, subscription_name
746                            )
747                        },
748                    ));
749                }
750                ObjectType::Source => {
751                    let sources: Vec<(String, String)> = Object::find()
752                        .join(JoinType::InnerJoin, object::Relation::Source.def())
753                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
754                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
755                        .select_only()
756                        .column(schema::Column::Name)
757                        .column(source::Column::Name)
758                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
759                        .into_tuple()
760                        .all(db)
761                        .await?;
762                    details.extend(sources.into_iter().map(|(schema_name, view_name)| {
763                        format!("source {}.{} depends on it", schema_name, view_name)
764                    }));
765                }
766                ObjectType::Connection => {
767                    let connections: Vec<(String, String)> = Object::find()
768                        .join(JoinType::InnerJoin, object::Relation::Connection.def())
769                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
770                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
771                        .select_only()
772                        .column(schema::Column::Name)
773                        .column(connection::Column::Name)
774                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
775                        .into_tuple()
776                        .all(db)
777                        .await?;
778                    details.extend(connections.into_iter().map(|(schema_name, view_name)| {
779                        format!("connection {}.{} depends on it", schema_name, view_name)
780                    }));
781                }
782                // only the table, source, sink, subscription, view, connection and index will depend on other objects.
783                _ => bail!("unexpected referring object type: {}", obj_type.as_str()),
784            }
785        }
786
787        return Err(MetaError::permission_denied(format!(
788            "{} used by {} other objects. \nDETAIL: {}\n\
789            {}",
790            object_type.as_str(),
791            count,
792            details.join("\n"),
793            match object_type {
794                ObjectType::Function | ObjectType::Connection | ObjectType::Secret =>
795                    "HINT: DROP the dependent objects first.",
796                ObjectType::Database | ObjectType::Schema => unreachable!(),
797                _ => "HINT:  Use DROP ... CASCADE to drop the dependent objects too.",
798            }
799        )));
800    }
801    Ok(())
802}
803
804/// List all objects that are using the given one.
805pub async fn get_referring_objects<C>(object_id: ObjectId, db: &C) -> MetaResult<Vec<PartialObject>>
806where
807    C: ConnectionTrait,
808{
809    let objs = ObjectDependency::find()
810        .filter(object_dependency::Column::Oid.eq(object_id))
811        .join(
812            JoinType::InnerJoin,
813            object_dependency::Relation::Object1.def(),
814        )
815        .into_partial_model()
816        .all(db)
817        .await?;
818
819    Ok(objs)
820}
821
822/// `ensure_schema_empty` ensures that the schema is empty, used by `DROP SCHEMA`.
823pub async fn ensure_schema_empty<C>(schema_id: SchemaId, db: &C) -> MetaResult<()>
824where
825    C: ConnectionTrait,
826{
827    let count = Object::find()
828        .filter(object::Column::SchemaId.eq(Some(schema_id)))
829        .count(db)
830        .await?;
831    if count != 0 {
832        return Err(MetaError::permission_denied("schema is not empty"));
833    }
834
835    Ok(())
836}
837
838/// `list_user_info_by_ids` lists all users' info by their ids.
839pub async fn list_user_info_by_ids<C>(
840    user_ids: impl IntoIterator<Item = UserId>,
841    db: &C,
842) -> MetaResult<Vec<PbUserInfo>>
843where
844    C: ConnectionTrait,
845{
846    let mut user_infos = vec![];
847    for user_id in user_ids {
848        let user = User::find_by_id(user_id)
849            .one(db)
850            .await?
851            .ok_or_else(|| MetaError::catalog_id_not_found("user", user_id))?;
852        let mut user_info: PbUserInfo = user.into();
853        user_info.grant_privileges = get_user_privilege(user_id, db).await?;
854        user_infos.push(user_info);
855    }
856    Ok(user_infos)
857}
858
859/// `get_object_owner` returns the owner of the given object.
860pub async fn get_object_owner<C>(object_id: ObjectId, db: &C) -> MetaResult<UserId>
861where
862    C: ConnectionTrait,
863{
864    let obj_owner: UserId = Object::find_by_id(object_id)
865        .select_only()
866        .column(object::Column::OwnerId)
867        .into_tuple()
868        .one(db)
869        .await?
870        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
871    Ok(obj_owner)
872}
873
874/// `construct_privilege_dependency_query` constructs a query to find all privileges that are dependent on the given one.
875///
876/// # Examples
877///
878/// ```
879/// use risingwave_meta::controller::utils::construct_privilege_dependency_query;
880/// use sea_orm::sea_query::*;
881/// use sea_orm::*;
882///
883/// let query = construct_privilege_dependency_query(vec![1, 2, 3]);
884///
885/// assert_eq!(
886///    query.to_string(MysqlQueryBuilder),
887///   r#"WITH RECURSIVE `granted_privilege_ids` (`id`, `user_id`) AS (SELECT `id`, `user_id` FROM `user_privilege` WHERE `user_privilege`.`id` IN (1, 2, 3) UNION ALL (SELECT `user_privilege`.`id`, `user_privilege`.`user_id` FROM `user_privilege` INNER JOIN `granted_privilege_ids` ON `granted_privilege_ids`.`id` = `dependent_id`)) SELECT `id`, `user_id` FROM `granted_privilege_ids`"#
888/// );
889/// assert_eq!(
890///   query.to_string(PostgresQueryBuilder),
891///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL (SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id")) SELECT "id", "user_id" FROM "granted_privilege_ids""#
892/// );
893/// assert_eq!(
894///  query.to_string(SqliteQueryBuilder),
895///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id") SELECT "id", "user_id" FROM "granted_privilege_ids""#
896/// );
897/// ```
898pub fn construct_privilege_dependency_query(ids: Vec<PrivilegeId>) -> WithQuery {
899    let cte_alias = Alias::new("granted_privilege_ids");
900    let cte_return_privilege_alias = Alias::new("id");
901    let cte_return_user_alias = Alias::new("user_id");
902
903    let mut base_query = SelectStatement::new()
904        .columns([user_privilege::Column::Id, user_privilege::Column::UserId])
905        .from(UserPrivilege)
906        .and_where(user_privilege::Column::Id.is_in(ids))
907        .to_owned();
908
909    let cte_referencing = Query::select()
910        .columns([
911            (UserPrivilege, user_privilege::Column::Id),
912            (UserPrivilege, user_privilege::Column::UserId),
913        ])
914        .from(UserPrivilege)
915        .inner_join(
916            cte_alias.clone(),
917            Expr::col((cte_alias.clone(), cte_return_privilege_alias.clone()))
918                .equals(user_privilege::Column::DependentId),
919        )
920        .to_owned();
921
922    let common_table_expr = CommonTableExpression::new()
923        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
924        .columns([
925            cte_return_privilege_alias.clone(),
926            cte_return_user_alias.clone(),
927        ])
928        .table_name(cte_alias.clone())
929        .to_owned();
930
931    SelectStatement::new()
932        .columns([cte_return_privilege_alias, cte_return_user_alias])
933        .from(cte_alias.clone())
934        .to_owned()
935        .with(
936            WithClause::new()
937                .recursive(true)
938                .cte(common_table_expr)
939                .to_owned(),
940        )
941        .to_owned()
942}
943
944pub async fn get_internal_tables_by_id<C>(job_id: ObjectId, db: &C) -> MetaResult<Vec<TableId>>
945where
946    C: ConnectionTrait,
947{
948    let table_ids: Vec<TableId> = Table::find()
949        .select_only()
950        .column(table::Column::TableId)
951        .filter(
952            table::Column::TableType
953                .eq(TableType::Internal)
954                .and(table::Column::BelongsToJobId.eq(job_id)),
955        )
956        .into_tuple()
957        .all(db)
958        .await?;
959    Ok(table_ids)
960}
961
962pub async fn get_index_state_tables_by_table_id<C>(
963    table_id: TableId,
964    db: &C,
965) -> MetaResult<Vec<TableId>>
966where
967    C: ConnectionTrait,
968{
969    let mut index_table_ids: Vec<TableId> = Index::find()
970        .select_only()
971        .column(index::Column::IndexTableId)
972        .filter(index::Column::PrimaryTableId.eq(table_id))
973        .into_tuple()
974        .all(db)
975        .await?;
976
977    if !index_table_ids.is_empty() {
978        let internal_table_ids: Vec<TableId> = Table::find()
979            .select_only()
980            .column(table::Column::TableId)
981            .filter(
982                table::Column::TableType
983                    .eq(TableType::Internal)
984                    .and(table::Column::BelongsToJobId.is_in(index_table_ids.clone())),
985            )
986            .into_tuple()
987            .all(db)
988            .await?;
989
990        index_table_ids.extend(internal_table_ids.into_iter());
991    }
992
993    Ok(index_table_ids)
994}
995
/// Projection of a `UserPrivilege` row onto its `id` and `user_id` columns.
///
/// Returned by `get_referring_privileges_cascade`.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "UserPrivilege")]
pub struct PartialUserPrivilege {
    /// The privilege id.
    pub id: PrivilegeId,
    /// The user the privilege is granted to.
    pub user_id: UserId,
}
1002
1003pub async fn get_referring_privileges_cascade<C>(
1004    ids: Vec<PrivilegeId>,
1005    db: &C,
1006) -> MetaResult<Vec<PartialUserPrivilege>>
1007where
1008    C: ConnectionTrait,
1009{
1010    let query = construct_privilege_dependency_query(ids);
1011    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
1012    let privileges = PartialUserPrivilege::find_by_statement(Statement::from_sql_and_values(
1013        db.get_database_backend(),
1014        sql,
1015        values,
1016    ))
1017    .all(db)
1018    .await?;
1019
1020    Ok(privileges)
1021}
1022
1023/// `ensure_privileges_not_referred` ensures that the privileges are not granted to any other users.
1024pub async fn ensure_privileges_not_referred<C>(ids: Vec<PrivilegeId>, db: &C) -> MetaResult<()>
1025where
1026    C: ConnectionTrait,
1027{
1028    let count = UserPrivilege::find()
1029        .filter(user_privilege::Column::DependentId.is_in(ids))
1030        .count(db)
1031        .await?;
1032    if count != 0 {
1033        return Err(MetaError::permission_denied(format!(
1034            "privileges granted to {} other ones.",
1035            count
1036        )));
1037    }
1038    Ok(())
1039}
1040
1041/// `get_user_privilege` returns the privileges of the given user.
1042pub async fn get_user_privilege<C>(user_id: UserId, db: &C) -> MetaResult<Vec<PbGrantPrivilege>>
1043where
1044    C: ConnectionTrait,
1045{
1046    let user_privileges = UserPrivilege::find()
1047        .find_also_related(Object)
1048        .filter(user_privilege::Column::UserId.eq(user_id))
1049        .all(db)
1050        .await?;
1051    Ok(user_privileges
1052        .into_iter()
1053        .map(|(privilege, object)| {
1054            let object = object.unwrap();
1055            let oid = object.oid as _;
1056            let obj = match object.obj_type {
1057                ObjectType::Database => PbGrantObject::DatabaseId(oid),
1058                ObjectType::Schema => PbGrantObject::SchemaId(oid),
1059                ObjectType::Table | ObjectType::Index => PbGrantObject::TableId(oid),
1060                ObjectType::Source => PbGrantObject::SourceId(oid),
1061                ObjectType::Sink => PbGrantObject::SinkId(oid),
1062                ObjectType::View => PbGrantObject::ViewId(oid),
1063                ObjectType::Function => PbGrantObject::FunctionId(oid),
1064                ObjectType::Connection => PbGrantObject::ConnectionId(oid),
1065                ObjectType::Subscription => PbGrantObject::SubscriptionId(oid),
1066                ObjectType::Secret => PbGrantObject::SecretId(oid),
1067            };
1068            PbGrantPrivilege {
1069                action_with_opts: vec![PbActionWithGrantOption {
1070                    action: PbAction::from(privilege.action) as _,
1071                    with_grant_option: privilege.with_grant_option,
1072                    granted_by: privilege.granted_by as _,
1073                }],
1074                object: Some(obj),
1075            }
1076        })
1077        .collect())
1078}
1079
1080pub async fn get_table_columns(
1081    txn: &impl ConnectionTrait,
1082    id: TableId,
1083) -> MetaResult<ColumnCatalogArray> {
1084    let columns = Table::find_by_id(id)
1085        .select_only()
1086        .columns([table::Column::Columns])
1087        .into_tuple::<ColumnCatalogArray>()
1088        .one(txn)
1089        .await?
1090        .ok_or_else(|| MetaError::catalog_id_not_found("table", id))?;
1091    Ok(columns)
1092}
1093
1094/// `grant_default_privileges_automatically` grants default privileges automatically
1095/// for the given new object. It returns the list of user infos whose privileges are updated.
1096pub async fn grant_default_privileges_automatically<C>(
1097    db: &C,
1098    object_id: ObjectId,
1099) -> MetaResult<Vec<PbUserInfo>>
1100where
1101    C: ConnectionTrait,
1102{
1103    let object = Object::find_by_id(object_id)
1104        .one(db)
1105        .await?
1106        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
1107    assert_ne!(object.obj_type, ObjectType::Database);
1108
1109    let for_mview_filter = if object.obj_type == ObjectType::Table {
1110        let table_type = Table::find_by_id(object_id)
1111            .select_only()
1112            .column(table::Column::TableType)
1113            .into_tuple::<TableType>()
1114            .one(db)
1115            .await?
1116            .ok_or_else(|| MetaError::catalog_id_not_found("table", object_id))?;
1117        user_default_privilege::Column::ForMaterializedView
1118            .eq(table_type == TableType::MaterializedView)
1119    } else {
1120        user_default_privilege::Column::ForMaterializedView.eq(false)
1121    };
1122    let schema_filter = if let Some(schema_id) = &object.schema_id {
1123        user_default_privilege::Column::SchemaId.eq(*schema_id)
1124    } else {
1125        user_default_privilege::Column::SchemaId.is_null()
1126    };
1127
1128    let default_privileges: Vec<(UserId, UserId, Action, bool)> = UserDefaultPrivilege::find()
1129        .select_only()
1130        .columns([
1131            user_default_privilege::Column::Grantee,
1132            user_default_privilege::Column::GrantedBy,
1133            user_default_privilege::Column::Action,
1134            user_default_privilege::Column::WithGrantOption,
1135        ])
1136        .filter(
1137            user_default_privilege::Column::DatabaseId
1138                .eq(object.database_id.unwrap())
1139                .and(schema_filter)
1140                .and(user_default_privilege::Column::UserId.eq(object.owner_id))
1141                .and(user_default_privilege::Column::ObjectType.eq(object.obj_type))
1142                .and(for_mview_filter),
1143        )
1144        .into_tuple()
1145        .all(db)
1146        .await?;
1147    if default_privileges.is_empty() {
1148        return Ok(vec![]);
1149    }
1150
1151    let updated_user_ids = default_privileges
1152        .iter()
1153        .map(|(grantee, _, _, _)| *grantee)
1154        .collect::<HashSet<_>>();
1155
1156    let internal_table_ids = get_internal_tables_by_id(object_id, db).await?;
1157
1158    for (grantee, granted_by, action, with_grant_option) in default_privileges {
1159        UserPrivilege::insert(user_privilege::ActiveModel {
1160            user_id: Set(grantee),
1161            oid: Set(object_id),
1162            granted_by: Set(granted_by),
1163            action: Set(action),
1164            with_grant_option: Set(with_grant_option),
1165            ..Default::default()
1166        })
1167        .exec(db)
1168        .await?;
1169        if action == Action::Select && !internal_table_ids.is_empty() {
1170            // Grant SELECT privilege for internal tables if the action is SELECT.
1171            for internal_table_id in &internal_table_ids {
1172                UserPrivilege::insert(user_privilege::ActiveModel {
1173                    user_id: Set(grantee),
1174                    oid: Set(*internal_table_id as _),
1175                    granted_by: Set(granted_by),
1176                    action: Set(Action::Select),
1177                    with_grant_option: Set(with_grant_option),
1178                    ..Default::default()
1179                })
1180                .exec(db)
1181                .await?;
1182            }
1183        }
1184    }
1185
1186    let updated_user_infos = list_user_info_by_ids(updated_user_ids, db).await?;
1187    Ok(updated_user_infos)
1188}
1189
1190// todo: remove it after migrated to sql backend.
1191pub fn extract_grant_obj_id(object: &PbGrantObject) -> ObjectId {
1192    match object {
1193        PbGrantObject::DatabaseId(id)
1194        | PbGrantObject::SchemaId(id)
1195        | PbGrantObject::TableId(id)
1196        | PbGrantObject::SourceId(id)
1197        | PbGrantObject::SinkId(id)
1198        | PbGrantObject::ViewId(id)
1199        | PbGrantObject::FunctionId(id)
1200        | PbGrantObject::SubscriptionId(id)
1201        | PbGrantObject::ConnectionId(id)
1202        | PbGrantObject::SecretId(id) => *id as _,
1203    }
1204}
1205
1206pub async fn insert_fragment_relations(
1207    db: &impl ConnectionTrait,
1208    downstream_fragment_relations: &FragmentDownstreamRelation,
1209) -> MetaResult<()> {
1210    for (upstream_fragment_id, downstreams) in downstream_fragment_relations {
1211        for downstream in downstreams {
1212            let relation = fragment_relation::Model {
1213                source_fragment_id: *upstream_fragment_id as _,
1214                target_fragment_id: downstream.downstream_fragment_id as _,
1215                dispatcher_type: downstream.dispatcher_type,
1216                dist_key_indices: downstream
1217                    .dist_key_indices
1218                    .iter()
1219                    .map(|idx| *idx as i32)
1220                    .collect_vec()
1221                    .into(),
1222                output_indices: downstream
1223                    .output_mapping
1224                    .indices
1225                    .iter()
1226                    .map(|idx| *idx as i32)
1227                    .collect_vec()
1228                    .into(),
1229                output_type_mapping: Some(downstream.output_mapping.types.clone().into()),
1230            };
1231            FragmentRelation::insert(relation.into_active_model())
1232                .exec(db)
1233                .await?;
1234        }
1235    }
1236    Ok(())
1237}
1238
/// Builds the dispatchers of the running actors of the given source fragments.
///
/// Loads every `fragment_relation` row whose source fragment is in `fragment_ids`. For each row,
/// it fetches the distribution type and running actors (with vnode bitmaps) of both the source and
/// target fragments, then passes them to [`compose_dispatchers`]. Each fragment is queried at most
/// once, thanks to a local cache.
///
/// The result maps source fragment id -> source actor id -> dispatchers, with one dispatcher per
/// outgoing relation of that fragment.
///
/// # Errors
///
/// Returns an error if the lookup for a source or target fragment yields no row, and on database
/// errors. NOTE(review): because the `Running` status filter applies to the joined query,
/// a fragment that exists but has no running actor presumably also hits this
/// "failed to find fragment" path — confirm.
///
/// # Panics
///
/// Panics if a fragment lookup returns more than one fragment, or on the invariant violations
/// checked by [`compose_dispatchers`].
pub async fn get_fragment_actor_dispatchers<C>(
    db: &C,
    fragment_ids: Vec<FragmentId>,
) -> MetaResult<FragmentActorDispatchers>
where
    C: ConnectionTrait,
{
    // (distribution type, actor id -> optional vnode bitmap) of one fragment. The actor map sits in
    // an `Arc` so cached entries can be handed out with a cheap clone.
    type FragmentActorInfo = (
        DistributionType,
        Arc<HashMap<crate::model::ActorId, Option<Bitmap>>>,
    );
    let mut fragment_actor_cache: HashMap<FragmentId, FragmentActorInfo> = HashMap::new();
    // Queries one fragment with its `Running` actors. Called only on cache misses below.
    let get_fragment_actors = |fragment_id: FragmentId| async move {
        let result: MetaResult<FragmentActorInfo> = try {
            let mut fragment_actors = Fragment::find_by_id(fragment_id)
                .find_with_related(Actor)
                .filter(actor::Column::Status.eq(ActorStatus::Running))
                .all(db)
                .await?;
            if fragment_actors.is_empty() {
                return Err(anyhow!("failed to find fragment: {}", fragment_id).into());
            }
            // `find_by_id` should yield a single fragment.
            assert_eq!(
                fragment_actors.len(),
                1,
                "find multiple fragment {:?}",
                fragment_actors
            );
            let (fragment, actors) = fragment_actors.pop().unwrap();
            (
                fragment.distribution_type,
                Arc::new(
                    actors
                        .into_iter()
                        .map(|actor| {
                            (
                                actor.actor_id as _,
                                // Convert the stored vnode bitmap into an in-memory `Bitmap`.
                                actor
                                    .vnode_bitmap
                                    .map(|bitmap| Bitmap::from(bitmap.to_protobuf())),
                            )
                        })
                        .collect(),
                ),
            )
        };
        result
    };
    let fragment_relations = FragmentRelation::find()
        .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids))
        .all(db)
        .await?;

    // source fragment id -> source actor id -> dispatchers of that actor.
    let mut actor_dispatchers_map: HashMap<_, HashMap<_, Vec<_>>> = HashMap::new();
    for fragment_relation::Model {
        source_fragment_id,
        target_fragment_id,
        dispatcher_type,
        dist_key_indices,
        output_indices,
        output_type_mapping,
    } in fragment_relations
    {
        // Source fragment info: served from the cache, or fetched and cached on first use.
        let (source_fragment_distribution, source_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(source_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(source_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // Target fragment info, obtained the same way.
        let (target_fragment_distribution, target_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(target_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(target_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // A missing stored type mapping falls back to its default value.
        let output_mapping = PbDispatchOutputMapping {
            indices: output_indices.into_u32_array(),
            types: output_type_mapping.unwrap_or_default().to_protobuf(),
        };
        let dispatchers = compose_dispatchers(
            source_fragment_distribution,
            &source_fragment_actors,
            target_fragment_id as _,
            target_fragment_distribution,
            &target_fragment_actors,
            dispatcher_type,
            dist_key_indices.into_u32_array(),
            output_mapping,
        );
        let actor_dispatchers_map = actor_dispatchers_map
            .entry(source_fragment_id as _)
            .or_default();
        // Append, because a source actor accumulates one dispatcher per outgoing relation.
        for (actor_id, dispatchers) in dispatchers {
            actor_dispatchers_map
                .entry(actor_id as _)
                .or_default()
                .push(dispatchers);
        }
    }
    Ok(actor_dispatchers_map)
}
1350
1351pub fn compose_dispatchers(
1352    source_fragment_distribution: DistributionType,
1353    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1354    target_fragment_id: crate::model::FragmentId,
1355    target_fragment_distribution: DistributionType,
1356    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1357    dispatcher_type: DispatcherType,
1358    dist_key_indices: Vec<u32>,
1359    output_mapping: PbDispatchOutputMapping,
1360) -> HashMap<crate::model::ActorId, PbDispatcher> {
1361    match dispatcher_type {
1362        DispatcherType::Hash => {
1363            let dispatcher = PbDispatcher {
1364                r#type: PbDispatcherType::from(dispatcher_type) as _,
1365                dist_key_indices: dist_key_indices.clone(),
1366                output_mapping: output_mapping.into(),
1367                hash_mapping: Some(
1368                    ActorMapping::from_bitmaps(
1369                        &target_fragment_actors
1370                            .iter()
1371                            .map(|(actor_id, bitmap)| {
1372                                (
1373                                    *actor_id as _,
1374                                    bitmap
1375                                        .clone()
1376                                        .expect("downstream hash dispatch must have distribution"),
1377                                )
1378                            })
1379                            .collect(),
1380                    )
1381                    .to_protobuf(),
1382                ),
1383                dispatcher_id: target_fragment_id as _,
1384                downstream_actor_id: target_fragment_actors
1385                    .keys()
1386                    .map(|actor_id| *actor_id as _)
1387                    .collect(),
1388            };
1389            source_fragment_actors
1390                .keys()
1391                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1392                .collect()
1393        }
1394        DispatcherType::Broadcast | DispatcherType::Simple => {
1395            let dispatcher = PbDispatcher {
1396                r#type: PbDispatcherType::from(dispatcher_type) as _,
1397                dist_key_indices: dist_key_indices.clone(),
1398                output_mapping: output_mapping.into(),
1399                hash_mapping: None,
1400                dispatcher_id: target_fragment_id as _,
1401                downstream_actor_id: target_fragment_actors
1402                    .keys()
1403                    .map(|actor_id| *actor_id as _)
1404                    .collect(),
1405            };
1406            source_fragment_actors
1407                .keys()
1408                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1409                .collect()
1410        }
1411        DispatcherType::NoShuffle => resolve_no_shuffle_actor_dispatcher(
1412            source_fragment_distribution,
1413            source_fragment_actors,
1414            target_fragment_distribution,
1415            target_fragment_actors,
1416        )
1417        .into_iter()
1418        .map(|(upstream_actor_id, downstream_actor_id)| {
1419            (
1420                upstream_actor_id,
1421                PbDispatcher {
1422                    r#type: PbDispatcherType::NoShuffle as _,
1423                    dist_key_indices: dist_key_indices.clone(),
1424                    output_mapping: output_mapping.clone().into(),
1425                    hash_mapping: None,
1426                    dispatcher_id: target_fragment_id as _,
1427                    downstream_actor_id: vec![downstream_actor_id as _],
1428                },
1429            )
1430        })
1431        .collect(),
1432    }
1433}
1434
1435/// return (`upstream_actor_id` -> `downstream_actor_id`)
1436pub fn resolve_no_shuffle_actor_dispatcher(
1437    source_fragment_distribution: DistributionType,
1438    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1439    target_fragment_distribution: DistributionType,
1440    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1441) -> Vec<(crate::model::ActorId, crate::model::ActorId)> {
1442    assert_eq!(source_fragment_distribution, target_fragment_distribution);
1443    assert_eq!(
1444        source_fragment_actors.len(),
1445        target_fragment_actors.len(),
1446        "no-shuffle should have equal upstream downstream actor count: {:?} {:?}",
1447        source_fragment_actors,
1448        target_fragment_actors
1449    );
1450    match source_fragment_distribution {
1451        DistributionType::Single => {
1452            let assert_singleton = |bitmap: &Option<Bitmap>| {
1453                assert!(
1454                    bitmap.as_ref().map(|bitmap| bitmap.all()).unwrap_or(true),
1455                    "not singleton: {:?}",
1456                    bitmap
1457                );
1458            };
1459            assert_eq!(
1460                source_fragment_actors.len(),
1461                1,
1462                "singleton distribution actor count not 1: {:?}",
1463                source_fragment_distribution
1464            );
1465            assert_eq!(
1466                target_fragment_actors.len(),
1467                1,
1468                "singleton distribution actor count not 1: {:?}",
1469                target_fragment_distribution
1470            );
1471            let (source_actor_id, bitmap) = source_fragment_actors.iter().next().unwrap();
1472            assert_singleton(bitmap);
1473            let (target_actor_id, bitmap) = target_fragment_actors.iter().next().unwrap();
1474            assert_singleton(bitmap);
1475            vec![(*source_actor_id, *target_actor_id)]
1476        }
1477        DistributionType::Hash => {
1478            let mut target_fragment_actor_index: HashMap<_, _> = target_fragment_actors
1479                .iter()
1480                .map(|(actor_id, bitmap)| {
1481                    let bitmap = bitmap
1482                        .as_ref()
1483                        .expect("hash distribution should have bitmap");
1484                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1485                    (first_vnode, (*actor_id, bitmap))
1486                })
1487                .collect();
1488            source_fragment_actors
1489                .iter()
1490                .map(|(source_actor_id, bitmap)| {
1491                    let bitmap = bitmap
1492                        .as_ref()
1493                        .expect("hash distribution should have bitmap");
1494                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1495                    let (target_actor_id, target_bitmap) =
1496                        target_fragment_actor_index.remove(&first_vnode).unwrap_or_else(|| {
1497                            panic!(
1498                                "cannot find matched target actor: {} {:?} {:?} {:?}",
1499                                source_actor_id,
1500                                first_vnode,
1501                                source_fragment_actors,
1502                                target_fragment_actors
1503                            );
1504                        });
1505                    assert_eq!(
1506                        bitmap,
1507                        target_bitmap,
1508                        "cannot find matched target actor due to bitmap mismatch: {} {:?} {:?} {:?}",
1509                        source_actor_id,
1510                        first_vnode,
1511                        source_fragment_actors,
1512                        target_fragment_actors
1513                    );
1514                    (*source_actor_id, target_actor_id)
1515                }).collect()
1516        }
1517    }
1518}
1519
1520/// `get_fragment_mappings` returns the fragment vnode mappings of the given job.
1521pub async fn get_fragment_mappings<C>(
1522    db: &C,
1523    job_id: ObjectId,
1524) -> MetaResult<Vec<PbFragmentWorkerSlotMapping>>
1525where
1526    C: ConnectionTrait,
1527{
1528    let job_actors: Vec<(
1529        FragmentId,
1530        DistributionType,
1531        ActorId,
1532        Option<VnodeBitmap>,
1533        WorkerId,
1534        ActorStatus,
1535    )> = Actor::find()
1536        .select_only()
1537        .columns([
1538            fragment::Column::FragmentId,
1539            fragment::Column::DistributionType,
1540        ])
1541        .columns([
1542            actor::Column::ActorId,
1543            actor::Column::VnodeBitmap,
1544            actor::Column::WorkerId,
1545            actor::Column::Status,
1546        ])
1547        .join(JoinType::InnerJoin, actor::Relation::Fragment.def())
1548        .filter(fragment::Column::JobId.eq(job_id))
1549        .into_tuple()
1550        .all(db)
1551        .await?;
1552
1553    Ok(rebuild_fragment_mapping_from_actors(job_actors))
1554}
1555
1556pub fn rebuild_fragment_mapping(fragment: &SharedFragmentInfo) -> PbFragmentWorkerSlotMapping {
1557    let fragment_worker_slot_mapping = match fragment.distribution_type {
1558        DistributionType::Single => {
1559            let actor = fragment.actors.values().exactly_one().unwrap();
1560            WorkerSlotMapping::new_single(WorkerSlotId::new(actor.worker_id as _, 0))
1561        }
1562        DistributionType::Hash => {
1563            let actor_bitmaps: HashMap<_, _> = fragment
1564                .actors
1565                .iter()
1566                .map(|(actor_id, actor_info)| {
1567                    let vnode_bitmap = actor_info
1568                        .vnode_bitmap
1569                        .as_ref()
1570                        .cloned()
1571                        .expect("actor bitmap shouldn't be none in hash fragment");
1572
1573                    (*actor_id as hash::ActorId, vnode_bitmap)
1574                })
1575                .collect();
1576
1577            let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
1578
1579            let actor_locations = fragment
1580                .actors
1581                .iter()
1582                .map(|(actor_id, actor_info)| {
1583                    (*actor_id as hash::ActorId, actor_info.worker_id as u32)
1584                })
1585                .collect();
1586
1587            actor_mapping.to_worker_slot(&actor_locations)
1588        }
1589    };
1590
1591    PbFragmentWorkerSlotMapping {
1592        fragment_id: fragment.fragment_id,
1593        mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1594    }
1595}
1596
1597pub fn rebuild_fragment_mapping_from_actors(
1598    job_actors: Vec<(
1599        FragmentId,
1600        DistributionType,
1601        ActorId,
1602        Option<VnodeBitmap>,
1603        WorkerId,
1604        ActorStatus,
1605    )>,
1606) -> Vec<FragmentWorkerSlotMapping> {
1607    let mut all_actor_locations = HashMap::new();
1608    let mut actor_bitmaps = HashMap::new();
1609    let mut fragment_actors = HashMap::new();
1610    let mut fragment_dist = HashMap::new();
1611
1612    for (fragment_id, dist, actor_id, bitmap, worker_id, actor_status) in job_actors {
1613        if actor_status == ActorStatus::Inactive {
1614            continue;
1615        }
1616
1617        all_actor_locations
1618            .entry(fragment_id)
1619            .or_insert(HashMap::new())
1620            .insert(actor_id as hash::ActorId, worker_id as u32);
1621        actor_bitmaps.insert(actor_id, bitmap);
1622        fragment_actors
1623            .entry(fragment_id)
1624            .or_insert_with(Vec::new)
1625            .push(actor_id);
1626        fragment_dist.insert(fragment_id, dist);
1627    }
1628
1629    let mut result = vec![];
1630    for (fragment_id, dist) in fragment_dist {
1631        let mut actor_locations = all_actor_locations.remove(&fragment_id).unwrap();
1632        let fragment_worker_slot_mapping = match dist {
1633            DistributionType::Single => {
1634                let actor = fragment_actors
1635                    .remove(&fragment_id)
1636                    .unwrap()
1637                    .into_iter()
1638                    .exactly_one()
1639                    .unwrap() as hash::ActorId;
1640                let actor_location = actor_locations.remove(&actor).unwrap();
1641
1642                WorkerSlotMapping::new_single(WorkerSlotId::new(actor_location, 0))
1643            }
1644            DistributionType::Hash => {
1645                let actors = fragment_actors.remove(&fragment_id).unwrap();
1646
1647                let all_actor_bitmaps: HashMap<_, _> = actors
1648                    .iter()
1649                    .map(|actor_id| {
1650                        let vnode_bitmap = actor_bitmaps
1651                            .remove(actor_id)
1652                            .flatten()
1653                            .expect("actor bitmap shouldn't be none in hash fragment");
1654
1655                        let bitmap = Bitmap::from(&vnode_bitmap.to_protobuf());
1656                        (*actor_id as hash::ActorId, bitmap)
1657                    })
1658                    .collect();
1659
1660                let actor_mapping = ActorMapping::from_bitmaps(&all_actor_bitmaps);
1661
1662                actor_mapping.to_worker_slot(&actor_locations)
1663            }
1664        };
1665
1666        result.push(PbFragmentWorkerSlotMapping {
1667            fragment_id: fragment_id as u32,
1668            mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1669        })
1670    }
1671    result
1672}
1673
1674/// `get_fragment_actor_ids` returns the fragment actor ids of the given fragments.
1675pub async fn get_fragment_actor_ids<C>(
1676    db: &C,
1677    fragment_ids: Vec<FragmentId>,
1678) -> MetaResult<HashMap<FragmentId, Vec<ActorId>>>
1679where
1680    C: ConnectionTrait,
1681{
1682    let fragment_actors: Vec<(FragmentId, ActorId)> = Actor::find()
1683        .select_only()
1684        .columns([actor::Column::FragmentId, actor::Column::ActorId])
1685        .filter(actor::Column::FragmentId.is_in(fragment_ids))
1686        .into_tuple()
1687        .all(db)
1688        .await?;
1689
1690    Ok(fragment_actors.into_iter().into_group_map())
1691}
1692
1693/// For the given streaming jobs, returns
1694/// - All source fragments
1695/// - All sink fragments
1696/// - All actors
1697/// - All fragments
1698pub async fn get_fragments_for_jobs<C>(
1699    db: &C,
1700    streaming_jobs: Vec<ObjectId>,
1701) -> MetaResult<(
1702    HashMap<SourceId, BTreeSet<FragmentId>>,
1703    HashSet<FragmentId>,
1704    HashSet<ActorId>,
1705    HashSet<FragmentId>,
1706)>
1707where
1708    C: ConnectionTrait,
1709{
1710    if streaming_jobs.is_empty() {
1711        return Ok((
1712            HashMap::default(),
1713            HashSet::default(),
1714            HashSet::default(),
1715            HashSet::default(),
1716        ));
1717    }
1718
1719    let fragments: Vec<(FragmentId, i32, StreamNode)> = Fragment::find()
1720        .select_only()
1721        .columns([
1722            fragment::Column::FragmentId,
1723            fragment::Column::FragmentTypeMask,
1724            fragment::Column::StreamNode,
1725        ])
1726        .filter(fragment::Column::JobId.is_in(streaming_jobs))
1727        .into_tuple()
1728        .all(db)
1729        .await?;
1730    let actors: Vec<ActorId> = Actor::find()
1731        .select_only()
1732        .column(actor::Column::ActorId)
1733        .filter(
1734            actor::Column::FragmentId.is_in(fragments.iter().map(|(id, _, _)| *id).collect_vec()),
1735        )
1736        .into_tuple()
1737        .all(db)
1738        .await?;
1739
1740    let fragment_ids = fragments
1741        .iter()
1742        .map(|(fragment_id, _, _)| *fragment_id)
1743        .collect();
1744
1745    let mut source_fragment_ids: HashMap<SourceId, BTreeSet<FragmentId>> = HashMap::new();
1746    let mut sink_fragment_ids: HashSet<FragmentId> = HashSet::new();
1747    for (fragment_id, mask, stream_node) in fragments {
1748        if FragmentTypeMask::from(mask).contains(FragmentTypeFlag::Source)
1749            && let Some(source_id) = stream_node.to_protobuf().find_stream_source()
1750        {
1751            source_fragment_ids
1752                .entry(source_id as _)
1753                .or_default()
1754                .insert(fragment_id);
1755        }
1756        if FragmentTypeMask::from(mask).contains(FragmentTypeFlag::Sink) {
1757            sink_fragment_ids.insert(fragment_id);
1758        }
1759    }
1760
1761    Ok((
1762        source_fragment_ids,
1763        sink_fragment_ids,
1764        actors.into_iter().collect(),
1765        fragment_ids,
1766    ))
1767}
1768
1769/// Build a object group for notifying the deletion of the given objects.
1770///
1771/// Note that only id fields are filled in the object info, as the arguments are partial objects.
1772/// As a result, the returned notification info should only be used for deletion.
1773pub(crate) fn build_object_group_for_delete(
1774    partial_objects: Vec<PartialObject>,
1775) -> NotificationInfo {
1776    let mut objects = vec![];
1777    for obj in partial_objects {
1778        match obj.obj_type {
1779            ObjectType::Database => objects.push(PbObject {
1780                object_info: Some(PbObjectInfo::Database(PbDatabase {
1781                    id: obj.oid as _,
1782                    ..Default::default()
1783                })),
1784            }),
1785            ObjectType::Schema => objects.push(PbObject {
1786                object_info: Some(PbObjectInfo::Schema(PbSchema {
1787                    id: obj.oid as _,
1788                    database_id: obj.database_id.unwrap() as _,
1789                    ..Default::default()
1790                })),
1791            }),
1792            ObjectType::Table => objects.push(PbObject {
1793                object_info: Some(PbObjectInfo::Table(PbTable {
1794                    id: obj.oid as _,
1795                    schema_id: obj.schema_id.unwrap() as _,
1796                    database_id: obj.database_id.unwrap() as _,
1797                    ..Default::default()
1798                })),
1799            }),
1800            ObjectType::Source => objects.push(PbObject {
1801                object_info: Some(PbObjectInfo::Source(PbSource {
1802                    id: obj.oid as _,
1803                    schema_id: obj.schema_id.unwrap() as _,
1804                    database_id: obj.database_id.unwrap() as _,
1805                    ..Default::default()
1806                })),
1807            }),
1808            ObjectType::Sink => objects.push(PbObject {
1809                object_info: Some(PbObjectInfo::Sink(PbSink {
1810                    id: obj.oid as _,
1811                    schema_id: obj.schema_id.unwrap() as _,
1812                    database_id: obj.database_id.unwrap() as _,
1813                    ..Default::default()
1814                })),
1815            }),
1816            ObjectType::Subscription => objects.push(PbObject {
1817                object_info: Some(PbObjectInfo::Subscription(PbSubscription {
1818                    id: obj.oid as _,
1819                    schema_id: obj.schema_id.unwrap() as _,
1820                    database_id: obj.database_id.unwrap() as _,
1821                    ..Default::default()
1822                })),
1823            }),
1824            ObjectType::View => objects.push(PbObject {
1825                object_info: Some(PbObjectInfo::View(PbView {
1826                    id: obj.oid as _,
1827                    schema_id: obj.schema_id.unwrap() as _,
1828                    database_id: obj.database_id.unwrap() as _,
1829                    ..Default::default()
1830                })),
1831            }),
1832            ObjectType::Index => {
1833                objects.push(PbObject {
1834                    object_info: Some(PbObjectInfo::Index(PbIndex {
1835                        id: obj.oid as _,
1836                        schema_id: obj.schema_id.unwrap() as _,
1837                        database_id: obj.database_id.unwrap() as _,
1838                        ..Default::default()
1839                    })),
1840                });
1841                objects.push(PbObject {
1842                    object_info: Some(PbObjectInfo::Table(PbTable {
1843                        id: obj.oid as _,
1844                        schema_id: obj.schema_id.unwrap() as _,
1845                        database_id: obj.database_id.unwrap() as _,
1846                        ..Default::default()
1847                    })),
1848                });
1849            }
1850            ObjectType::Function => objects.push(PbObject {
1851                object_info: Some(PbObjectInfo::Function(PbFunction {
1852                    id: obj.oid as _,
1853                    schema_id: obj.schema_id.unwrap() as _,
1854                    database_id: obj.database_id.unwrap() as _,
1855                    ..Default::default()
1856                })),
1857            }),
1858            ObjectType::Connection => objects.push(PbObject {
1859                object_info: Some(PbObjectInfo::Connection(PbConnection {
1860                    id: obj.oid as _,
1861                    schema_id: obj.schema_id.unwrap() as _,
1862                    database_id: obj.database_id.unwrap() as _,
1863                    ..Default::default()
1864                })),
1865            }),
1866            ObjectType::Secret => objects.push(PbObject {
1867                object_info: Some(PbObjectInfo::Secret(PbSecret {
1868                    id: obj.oid as _,
1869                    schema_id: obj.schema_id.unwrap() as _,
1870                    database_id: obj.database_id.unwrap() as _,
1871                    ..Default::default()
1872                })),
1873            }),
1874        }
1875    }
1876    NotificationInfo::ObjectGroup(PbObjectGroup { objects })
1877}
1878
1879pub fn extract_external_table_name_from_definition(table_definition: &str) -> Option<String> {
1880    let [mut definition]: [_; 1] = Parser::parse_sql(table_definition)
1881        .context("unable to parse table definition")
1882        .inspect_err(|e| {
1883            tracing::error!(
1884                target: "auto_schema_change",
1885                error = %e.as_report(),
1886                "failed to parse table definition")
1887        })
1888        .unwrap()
1889        .try_into()
1890        .unwrap();
1891    if let SqlStatement::CreateTable { cdc_table_info, .. } = &mut definition {
1892        cdc_table_info
1893            .clone()
1894            .map(|cdc_table_info| cdc_table_info.external_table_name)
1895    } else {
1896        None
1897    }
1898}
1899
/// `rename_relation` renames the target relation to `object_name` and rewrites its stored
/// definition accordingly.
///
/// All updates are issued through `txn`; this function does not commit the transaction itself.
/// It returns the updated relations as protobuf objects, together with the relation's old name.
///
/// Per `object_type`:
/// - `Table`: the table is renamed. If a source has this table as its
///   `optional_associated_table_id`, that source is renamed as well. The returned old name is the
///   table's.
/// - `Source` / `Sink` / `Subscription` / `View`: only the relation itself is renamed. For a view,
///   only the name is updated; its definition is not rewritten.
/// - `Index`: the index table and the index row are both renamed to `object_name`. The returned
///   old name is read from the index table.
///
/// # Panics
/// Panics if the relation or its `object` row does not exist, or if `object_type` is not one of
/// the relation types listed above.
pub async fn rename_relation(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
) -> MetaResult<(Vec<PbObject>, String)> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename;

    let mut to_update_relations = vec![];
    // rename relation.
    // Loads `$entity` `$object_id` with its `object` row, sets the new name (and, except for
    // views, rewrites the definition), persists only the id/name/definition columns, records the
    // updated object in `to_update_relations`, and evaluates to the old name.
    macro_rules! rename_relation {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            let obj = obj.unwrap();
            let old_name = relation.name.clone();
            relation.name = object_name.into();
            // NOTE(review): view definitions are deliberately left untouched here; confirm the
            // reason (presumably the stored view definition does not embed the view's own name).
            if obj.obj_type != ObjectType::View {
                relation.definition = alter_relation_rename(&relation.definition, object_name);
            }
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                name: Set(object_name.into()),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(ObjectModel(relation, obj).into())),
            });
            old_name
        }};
    }
    // TODO: check is there any thing to change for shared source?
    let old_name = match object_type {
        ObjectType::Table => {
            // A table may have an associated source; keep its name in sync with the table.
            let associated_source_id: Option<SourceId> = Source::find()
                .select_only()
                .column(source::Column::SourceId)
                .filter(source::Column::OptionalAssociatedTableId.eq(object_id))
                .into_tuple()
                .one(txn)
                .await?;
            if let Some(source_id) = associated_source_id {
                rename_relation!(Source, source, source_id, source_id);
            }
            rename_relation!(Table, table, table_id, object_id)
        }
        ObjectType::Source => rename_relation!(Source, source, source_id, object_id),
        ObjectType::Sink => rename_relation!(Sink, sink, sink_id, object_id),
        ObjectType::Subscription => {
            rename_relation!(Subscription, subscription, subscription_id, object_id)
        }
        ObjectType::View => rename_relation!(View, view, view_id, object_id),
        ObjectType::Index => {
            let (mut index, obj) = Index::find_by_id(object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            index.name = object_name.into();
            let index_table_id = index.index_table_id;
            // Rename the backing index table first; its old name is what this function returns.
            let old_name = rename_relation!(Table, table, table_id, index_table_id);

            // the name of index and its associated table is the same.
            let active_model = index::ActiveModel {
                index_id: sea_orm::ActiveValue::Set(index.index_id),
                name: sea_orm::ActiveValue::Set(object_name.into()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::Index(ObjectModel(index, obj.unwrap()).into())),
            });
            old_name
        }
        _ => unreachable!("only relation name can be altered."),
    };

    Ok((to_update_relations, old_name))
}
1988
1989pub async fn get_database_resource_group<C>(txn: &C, database_id: ObjectId) -> MetaResult<String>
1990where
1991    C: ConnectionTrait,
1992{
1993    let database_resource_group: Option<String> = Database::find_by_id(database_id)
1994        .select_only()
1995        .column(database::Column::ResourceGroup)
1996        .into_tuple()
1997        .one(txn)
1998        .await?
1999        .ok_or_else(|| MetaError::catalog_id_not_found("database", database_id))?;
2000
2001    Ok(database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned()))
2002}
2003
2004pub async fn get_existing_job_resource_group<C>(
2005    txn: &C,
2006    streaming_job_id: ObjectId,
2007) -> MetaResult<String>
2008where
2009    C: ConnectionTrait,
2010{
2011    let (job_specific_resource_group, database_resource_group): (Option<String>, Option<String>) =
2012        StreamingJob::find_by_id(streaming_job_id)
2013            .select_only()
2014            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
2015            .join(JoinType::InnerJoin, object::Relation::Database2.def())
2016            .column(streaming_job::Column::SpecificResourceGroup)
2017            .column(database::Column::ResourceGroup)
2018            .into_tuple()
2019            .one(txn)
2020            .await?
2021            .ok_or_else(|| MetaError::catalog_id_not_found("streaming job", streaming_job_id))?;
2022
2023    Ok(job_specific_resource_group.unwrap_or_else(|| {
2024        database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned())
2025    }))
2026}
2027
2028pub fn filter_workers_by_resource_group(
2029    workers: &HashMap<u32, WorkerNode>,
2030    resource_group: &str,
2031) -> BTreeSet<WorkerId> {
2032    workers
2033        .iter()
2034        .filter(|&(_, worker)| {
2035            worker
2036                .resource_group()
2037                .map(|node_label| node_label.as_str() == resource_group)
2038                .unwrap_or(false)
2039        })
2040        .map(|(id, _)| (*id as WorkerId))
2041        .collect()
2042}
2043
/// `rename_relation_refer` rewrites the definitions of the relations that refer to the object
/// `object_id`, replacing references to `old_name` with `object_name`.
///
/// The referring objects come from `get_referring_objects`. When the renamed object is a table,
/// every sink whose `target_table` is that table is included as well. For a referring index, the
/// definition of its index table is rewritten (the index row itself is not touched).
///
/// All updates are issued through `txn`; this function does not commit the transaction itself.
/// Returns every updated relation as a protobuf object.
///
/// # Errors
/// Fails if a referring object is not a table, sink, subscription, view or index.
///
/// # Panics
/// Panics if a referring relation, its `object` row, or an index's table id cannot be found.
pub async fn rename_relation_refer(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
    old_name: &str,
) -> MetaResult<Vec<PbObject>> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename_refs;

    let mut to_update_relations = vec![];
    // Loads `$entity` `$object_id` with its `object` row, rewrites references in its definition,
    // persists only the id/definition columns, and records the updated object.
    macro_rules! rename_relation_ref {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            relation.definition =
                alter_relation_rename_refs(&relation.definition, old_name, object_name);
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(
                    ObjectModel(relation, obj.unwrap()).into(),
                )),
            });
        }};
    }
    let mut objs = get_referring_objects(object_id, txn).await?;
    if object_type == ObjectType::Table {
        // Sinks targeting this table also mention its name in their definitions.
        // NOTE(review): there is no de-duplication; if a sink is also returned by
        // `get_referring_objects`, it is rewritten twice and appears twice in the result. Confirm
        // this cannot happen.
        let incoming_sinks: Vec<SinkId> = Sink::find()
            .select_only()
            .column(sink::Column::SinkId)
            .filter(sink::Column::TargetTable.eq(object_id))
            .into_tuple()
            .all(txn)
            .await?;

        objs.extend(incoming_sinks.into_iter().map(|id| PartialObject {
            oid: id,
            obj_type: ObjectType::Sink,
            schema_id: None,
            database_id: None,
        }));
    }

    for obj in objs {
        match obj.obj_type {
            ObjectType::Table => rename_relation_ref!(Table, table, table_id, obj.oid),
            ObjectType::Sink => rename_relation_ref!(Sink, sink, sink_id, obj.oid),
            ObjectType::Subscription => {
                rename_relation_ref!(Subscription, subscription, subscription_id, obj.oid)
            }
            ObjectType::View => rename_relation_ref!(View, view, view_id, obj.oid),
            ObjectType::Index => {
                // The definition lives on the index's backing table, so rewrite that table.
                let index_table_id: Option<TableId> = Index::find_by_id(obj.oid)
                    .select_only()
                    .column(index::Column::IndexTableId)
                    .into_tuple()
                    .one(txn)
                    .await?;
                rename_relation_ref!(Table, table, table_id, index_table_id.unwrap());
            }
            _ => {
                bail!(
                    "only the table, sink, subscription, view and index will depend on other objects."
                )
            }
        }
    }

    Ok(to_update_relations)
}
2125
2126/// Validate that subscription can be safely deleted, meeting any of the following conditions:
2127/// 1. The upstream table is not referred to by any cross-db mv.
2128/// 2. After deleting the subscription, the upstream table still has at least one subscription.
2129pub async fn validate_subscription_deletion<C>(txn: &C, subscription_id: ObjectId) -> MetaResult<()>
2130where
2131    C: ConnectionTrait,
2132{
2133    let upstream_table_id: ObjectId = Subscription::find_by_id(subscription_id)
2134        .select_only()
2135        .column(subscription::Column::DependentTableId)
2136        .into_tuple()
2137        .one(txn)
2138        .await?
2139        .ok_or_else(|| MetaError::catalog_id_not_found("subscription", subscription_id))?;
2140
2141    let cnt = Subscription::find()
2142        .filter(subscription::Column::DependentTableId.eq(upstream_table_id))
2143        .count(txn)
2144        .await?;
2145    if cnt > 1 {
2146        // Ensure that at least one subscription is remained for the upstream table
2147        // once the subscription is dropped.
2148        return Ok(());
2149    }
2150
2151    // Ensure that the upstream table is not referred by any cross-db mv.
2152    let obj_alias = Alias::new("o1");
2153    let used_by_alias = Alias::new("o2");
2154    let count = ObjectDependency::find()
2155        .join_as(
2156            JoinType::InnerJoin,
2157            object_dependency::Relation::Object2.def(),
2158            obj_alias.clone(),
2159        )
2160        .join_as(
2161            JoinType::InnerJoin,
2162            object_dependency::Relation::Object1.def(),
2163            used_by_alias.clone(),
2164        )
2165        .filter(
2166            object_dependency::Column::Oid
2167                .eq(upstream_table_id)
2168                .and(object_dependency::Column::UsedBy.ne(subscription_id))
2169                .and(
2170                    Expr::col((obj_alias, object::Column::DatabaseId))
2171                        .ne(Expr::col((used_by_alias, object::Column::DatabaseId))),
2172                ),
2173        )
2174        .count(txn)
2175        .await?;
2176
2177    if count != 0 {
2178        return Err(MetaError::permission_denied(format!(
2179            "Referenced by {} cross-db objects.",
2180            count
2181        )));
2182    }
2183
2184    Ok(())
2185}
2186
2187pub async fn fetch_target_fragments<C>(
2188    txn: &C,
2189    src_fragment_id: impl IntoIterator<Item = FragmentId>,
2190) -> MetaResult<HashMap<FragmentId, Vec<FragmentId>>>
2191where
2192    C: ConnectionTrait,
2193{
2194    let source_target_fragments: Vec<(FragmentId, FragmentId)> = FragmentRelation::find()
2195        .select_only()
2196        .columns([
2197            fragment_relation::Column::SourceFragmentId,
2198            fragment_relation::Column::TargetFragmentId,
2199        ])
2200        .filter(fragment_relation::Column::SourceFragmentId.is_in(src_fragment_id))
2201        .into_tuple()
2202        .all(txn)
2203        .await?;
2204
2205    let source_target_fragments = source_target_fragments.into_iter().into_group_map();
2206
2207    Ok(source_target_fragments)
2208}
2209
2210pub async fn get_sink_fragment_by_ids<C>(
2211    txn: &C,
2212    sink_ids: Vec<SinkId>,
2213) -> MetaResult<HashMap<SinkId, FragmentId>>
2214where
2215    C: ConnectionTrait,
2216{
2217    let sink_num = sink_ids.len();
2218    let sink_fragment_ids: Vec<(SinkId, FragmentId)> = Fragment::find()
2219        .select_only()
2220        .columns([fragment::Column::JobId, fragment::Column::FragmentId])
2221        .filter(
2222            fragment::Column::JobId
2223                .is_in(sink_ids)
2224                .and(FragmentTypeMask::intersects(FragmentTypeFlag::Sink)),
2225        )
2226        .into_tuple()
2227        .all(txn)
2228        .await?;
2229
2230    if sink_fragment_ids.len() != sink_num {
2231        return Err(anyhow::anyhow!(
2232            "expected exactly one sink fragment for each sink, but got {} fragments for {} sinks",
2233            sink_fragment_ids.len(),
2234            sink_num
2235        )
2236        .into());
2237    }
2238
2239    Ok(sink_fragment_ids.into_iter().collect())
2240}
2241
2242pub async fn has_table_been_migrated<C>(txn: &C, table_id: TableId) -> MetaResult<bool>
2243where
2244    C: ConnectionTrait,
2245{
2246    let mview_fragment: Vec<i32> = Fragment::find()
2247        .select_only()
2248        .column(fragment::Column::FragmentTypeMask)
2249        .filter(
2250            fragment::Column::JobId
2251                .eq(table_id)
2252                .and(FragmentTypeMask::intersects(FragmentTypeFlag::Mview)),
2253        )
2254        .into_tuple()
2255        .all(txn)
2256        .await?;
2257
2258    let mview_fragment_len = mview_fragment.len();
2259    if mview_fragment_len != 1 {
2260        bail!(
2261            "expected exactly one mview fragment for table {}, found {}",
2262            table_id,
2263            mview_fragment_len
2264        );
2265    }
2266
2267    let mview_fragment = mview_fragment.into_iter().next().unwrap();
2268    let migrated =
2269        FragmentTypeMask::from(mview_fragment).contains(FragmentTypeFlag::UpstreamSinkUnion);
2270
2271    Ok(migrated)
2272}
2273
2274pub fn build_select_node_list(
2275    from: &[ColumnCatalog],
2276    to: &[ColumnCatalog],
2277) -> MetaResult<Vec<PbExprNode>> {
2278    let mut exprs = Vec::with_capacity(to.len());
2279    let idx_by_col_id = from
2280        .iter()
2281        .enumerate()
2282        .map(|(idx, col)| (col.column_desc.as_ref().unwrap().column_id, idx))
2283        .collect::<HashMap<_, _>>();
2284
2285    for to_col in to {
2286        let to_col = to_col.column_desc.as_ref().unwrap();
2287        let to_col_type = to_col.column_type.clone();
2288        if let Some(from_idx) = idx_by_col_id.get(&to_col.column_id) {
2289            let from_col_type = from[*from_idx]
2290                .column_desc
2291                .as_ref()
2292                .unwrap()
2293                .column_type
2294                .clone();
2295            if to_col_type != from_col_type {
2296                return Err(anyhow!(
2297                    "Column type mismatch: {:?} != {:?}",
2298                    from_col_type,
2299                    to_col_type
2300                )
2301                .into());
2302            }
2303            exprs.push(PbExprNode {
2304                function_type: expr_node::Type::Unspecified.into(),
2305                return_type: to_col_type,
2306                rex_node: Some(expr_node::RexNode::InputRef(*from_idx as _)),
2307            });
2308        } else {
2309            let to_default_node =
2310                if let Some(GeneratedOrDefaultColumn::DefaultColumn(DefaultColumnDesc {
2311                    expr,
2312                    ..
2313                })) = &to_col.generated_or_default_column
2314                {
2315                    expr.clone().unwrap()
2316                } else {
2317                    let null = Datum::None.to_protobuf();
2318                    PbExprNode {
2319                        function_type: expr_node::Type::Unspecified.into(),
2320                        return_type: to_col_type,
2321                        rex_node: Some(expr_node::RexNode::Constant(null)),
2322                    }
2323                };
2324            exprs.push(to_default_node);
2325        }
2326    }
2327
2328    Ok(exprs)
2329}
2330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_cdc_table_name() {
        let ddl1 = "CREATE TABLE t1 () FROM pg_source TABLE 'public.t1'";
        let ddl2 = "CREATE TABLE t2 (v1 int) FROM pg_source TABLE 'mydb.t2'";
        assert_eq!(
            extract_external_table_name_from_definition(ddl1),
            Some("public.t1".into())
        );
        assert_eq!(
            extract_external_table_name_from_definition(ddl2),
            Some("mydb.t2".into())
        );
    }

    #[test]
    fn test_extract_cdc_table_name_from_non_cdc_table() {
        // A table not created from a CDC source carries no external table name.
        let ddl = "CREATE TABLE t3 (v1 int)";
        assert_eq!(extract_external_table_name_from_definition(ddl), None);
    }
}