risingwave_meta/controller/
utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::{BTreeSet, HashMap, HashSet};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
23use risingwave_common::hash::{ActorMapping, VnodeBitmapExt, WorkerSlotId, WorkerSlotMapping};
24use risingwave_common::types::Datum;
25use risingwave_common::util::value_encoding::DatumToProtoExt;
26use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
27use risingwave_common::{bail, hash};
28use risingwave_meta_model::actor::ActorStatus;
29use risingwave_meta_model::fragment::DistributionType;
30use risingwave_meta_model::object::ObjectType;
31use risingwave_meta_model::prelude::*;
32use risingwave_meta_model::table::TableType;
33use risingwave_meta_model::user_privilege::Action;
34use risingwave_meta_model::{
35    ActorId, ColumnCatalogArray, DataTypeArray, DatabaseId, DispatcherType, FragmentId, I32Array,
36    JobStatus, ObjectId, PrivilegeId, SchemaId, SinkId, SourceId, StreamNode, StreamSourceInfo,
37    TableId, UserId, VnodeBitmap, WorkerId, actor, connection, database, fragment,
38    fragment_relation, function, index, object, object_dependency, schema, secret, sink, source,
39    streaming_job, subscription, table, user, user_default_privilege, user_privilege, view,
40};
41use risingwave_meta_model_migration::WithQuery;
42use risingwave_pb::catalog::{
43    PbConnection, PbDatabase, PbFunction, PbIndex, PbSchema, PbSecret, PbSink, PbSource,
44    PbSubscription, PbTable, PbView,
45};
46use risingwave_pb::common::WorkerNode;
47use risingwave_pb::expr::{PbExprNode, expr_node};
48use risingwave_pb::meta::object::PbObjectInfo;
49use risingwave_pb::meta::subscribe_response::Info as NotificationInfo;
50use risingwave_pb::meta::{
51    FragmentWorkerSlotMapping, PbFragmentWorkerSlotMapping, PbObject, PbObjectGroup,
52};
53use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
54use risingwave_pb::plan_common::{ColumnCatalog, DefaultColumnDesc};
55use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher, PbDispatcherType};
56use risingwave_pb::user::grant_privilege::{PbActionWithGrantOption, PbObject as PbGrantObject};
57use risingwave_pb::user::{PbAction, PbGrantPrivilege, PbUserInfo};
58use risingwave_sqlparser::ast::Statement as SqlStatement;
59use risingwave_sqlparser::parser::Parser;
60use sea_orm::sea_query::{
61    Alias, CommonTableExpression, Expr, Query, QueryStatementBuilder, SelectStatement, UnionType,
62    WithClause,
63};
64use sea_orm::{
65    ColumnTrait, ConnectionTrait, DatabaseTransaction, DerivePartialModel, EntityTrait,
66    FromQueryResult, IntoActiveModel, JoinType, Order, PaginatorTrait, QueryFilter, QuerySelect,
67    RelationTrait, Set, Statement,
68};
69use thiserror_ext::AsReport;
70
71use crate::barrier::SharedFragmentInfo;
72use crate::controller::ObjectModel;
73use crate::controller::fragment::FragmentTypeMaskExt;
74use crate::model::{FragmentActorDispatchers, FragmentDownstreamRelation};
75use crate::{MetaError, MetaResult};
76
77/// This function will construct a query using recursive cte to find all objects[(id, `obj_type`)] that are used by the given object.
78///
79/// # Examples
80///
81/// ```
82/// use risingwave_meta::controller::utils::construct_obj_dependency_query;
83/// use sea_orm::sea_query::*;
84/// use sea_orm::*;
85///
86/// let query = construct_obj_dependency_query(1);
87///
88/// assert_eq!(
89///     query.to_string(MysqlQueryBuilder),
90///     r#"WITH RECURSIVE `used_by_object_ids` (`used_by`) AS (SELECT `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `oid` FROM `object` WHERE `object`.`database_id` = 1 OR `object`.`schema_id` = 1) UNION ALL (SELECT `object_dependency`.`used_by` FROM `object_dependency` INNER JOIN `used_by_object_ids` ON `used_by_object_ids`.`used_by` = `oid`)) SELECT DISTINCT `oid`, `obj_type`, `schema_id`, `database_id` FROM `used_by_object_ids` INNER JOIN `object` ON `used_by_object_ids`.`used_by` = `oid` ORDER BY `oid` DESC"#
91/// );
92/// assert_eq!(
93///     query.to_string(PostgresQueryBuilder),
94///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1) UNION ALL (SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid")) SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
95/// );
96/// assert_eq!(
97///     query.to_string(SqliteQueryBuilder),
98///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1 UNION ALL SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid") SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
99/// );
100/// ```
101pub fn construct_obj_dependency_query(obj_id: ObjectId) -> WithQuery {
102    let cte_alias = Alias::new("used_by_object_ids");
103    let cte_return_alias = Alias::new("used_by");
104
105    let mut base_query = SelectStatement::new()
106        .column(object_dependency::Column::UsedBy)
107        .from(ObjectDependency)
108        .and_where(object_dependency::Column::Oid.eq(obj_id))
109        .to_owned();
110
111    let belonged_obj_query = SelectStatement::new()
112        .column(object::Column::Oid)
113        .from(Object)
114        .and_where(
115            object::Column::DatabaseId
116                .eq(obj_id)
117                .or(object::Column::SchemaId.eq(obj_id)),
118        )
119        .to_owned();
120
121    let cte_referencing = Query::select()
122        .column((ObjectDependency, object_dependency::Column::UsedBy))
123        .from(ObjectDependency)
124        .inner_join(
125            cte_alias.clone(),
126            Expr::col((cte_alias.clone(), cte_return_alias.clone()))
127                .equals(object_dependency::Column::Oid),
128        )
129        .to_owned();
130
131    let mut common_table_expr = CommonTableExpression::new();
132    common_table_expr
133        .query(
134            base_query
135                .union(UnionType::All, belonged_obj_query)
136                .union(UnionType::All, cte_referencing)
137                .to_owned(),
138        )
139        .column(cte_return_alias.clone())
140        .table_name(cte_alias.clone());
141
142    SelectStatement::new()
143        .distinct()
144        .columns([
145            object::Column::Oid,
146            object::Column::ObjType,
147            object::Column::SchemaId,
148            object::Column::DatabaseId,
149        ])
150        .from(cte_alias.clone())
151        .inner_join(
152            Object,
153            Expr::col((cte_alias, cte_return_alias.clone())).equals(object::Column::Oid),
154        )
155        .order_by(object::Column::Oid, Order::Desc)
156        .to_owned()
157        .with(
158            WithClause::new()
159                .recursive(true)
160                .cte(common_table_expr)
161                .to_owned(),
162        )
163}
164
165/// This function will construct a query using recursive cte to find if dependent objects are already relying on the target table.
166///
167/// # Examples
168///
169/// ```
170/// use risingwave_meta::controller::utils::construct_sink_cycle_check_query;
171/// use sea_orm::sea_query::*;
172/// use sea_orm::*;
173///
174/// let query = construct_sink_cycle_check_query(1, vec![2, 3]);
175///
176/// assert_eq!(
177///     query.to_string(MysqlQueryBuilder),
178///     r#"WITH RECURSIVE `used_by_object_ids_with_sink` (`oid`, `used_by`) AS (SELECT `oid`, `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `obj_dependency_with_sink`.`oid`, `obj_dependency_with_sink`.`used_by` FROM (SELECT `oid`, `used_by` FROM `object_dependency` UNION ALL (SELECT `sink_id`, `target_table` FROM `sink` WHERE `sink`.`target_table` IS NOT NULL)) AS `obj_dependency_with_sink` INNER JOIN `used_by_object_ids_with_sink` ON `used_by_object_ids_with_sink`.`used_by` = `obj_dependency_with_sink`.`oid` WHERE `used_by_object_ids_with_sink`.`used_by` <> `used_by_object_ids_with_sink`.`oid`)) SELECT COUNT(`used_by_object_ids_with_sink`.`used_by`) FROM `used_by_object_ids_with_sink` WHERE `used_by_object_ids_with_sink`.`used_by` IN (2, 3)"#
179/// );
180/// assert_eq!(
181///     query.to_string(PostgresQueryBuilder),
182///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL (SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL)) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid")) SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
183/// );
184/// assert_eq!(
185///     query.to_string(SqliteQueryBuilder),
186///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid") SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
187/// );
188/// ```
189pub fn construct_sink_cycle_check_query(
190    target_table: ObjectId,
191    dependent_objects: Vec<ObjectId>,
192) -> WithQuery {
193    let cte_alias = Alias::new("used_by_object_ids_with_sink");
194    let depend_alias = Alias::new("obj_dependency_with_sink");
195
196    let mut base_query = SelectStatement::new()
197        .columns([
198            object_dependency::Column::Oid,
199            object_dependency::Column::UsedBy,
200        ])
201        .from(ObjectDependency)
202        .and_where(object_dependency::Column::Oid.eq(target_table))
203        .to_owned();
204
205    let query_sink_deps = SelectStatement::new()
206        .columns([sink::Column::SinkId, sink::Column::TargetTable])
207        .from(Sink)
208        .and_where(sink::Column::TargetTable.is_not_null())
209        .to_owned();
210
211    let cte_referencing = Query::select()
212        .column((depend_alias.clone(), object_dependency::Column::Oid))
213        .column((depend_alias.clone(), object_dependency::Column::UsedBy))
214        .from_subquery(
215            SelectStatement::new()
216                .columns([
217                    object_dependency::Column::Oid,
218                    object_dependency::Column::UsedBy,
219                ])
220                .from(ObjectDependency)
221                .union(UnionType::All, query_sink_deps)
222                .to_owned(),
223            depend_alias.clone(),
224        )
225        .inner_join(
226            cte_alias.clone(),
227            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).eq(Expr::col((
228                depend_alias.clone(),
229                object_dependency::Column::Oid,
230            ))),
231        )
232        .and_where(
233            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).ne(Expr::col((
234                cte_alias.clone(),
235                object_dependency::Column::Oid,
236            ))),
237        )
238        .to_owned();
239
240    let mut common_table_expr = CommonTableExpression::new();
241    common_table_expr
242        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
243        .columns([
244            object_dependency::Column::Oid,
245            object_dependency::Column::UsedBy,
246        ])
247        .table_name(cte_alias.clone());
248
249    SelectStatement::new()
250        .expr(Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).count())
251        .from(cte_alias.clone())
252        .and_where(
253            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy))
254                .is_in(dependent_objects),
255        )
256        .to_owned()
257        .with(
258            WithClause::new()
259                .recursive(true)
260                .cte(common_table_expr)
261                .to_owned(),
262        )
263}
264
/// Partial model of the `Object` entity: an object's id, type and namespace ids.
///
/// This is the row type decoded from the query built by [`construct_obj_dependency_query`]
/// (see [`get_referring_objects_cascade`]).
#[derive(Clone, DerivePartialModel, FromQueryResult, Debug)]
#[sea_orm(entity = "Object")]
pub struct PartialObject {
    pub oid: ObjectId,
    pub obj_type: ObjectType,
    /// Owning schema, if any (presumably `None` for databases and schemas — TODO confirm).
    pub schema_id: Option<SchemaId>,
    /// Owning database, if any.
    pub database_id: Option<DatabaseId>,
}
273
/// Partial model of the `Fragment` entity selecting only the fragment id, its `job_id` and
/// the ids of its state tables.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Fragment")]
pub struct PartialFragmentStateTables {
    pub fragment_id: FragmentId,
    pub job_id: ObjectId,
    pub state_table_ids: I32Array,
}
281
/// Partial model of the `Actor` entity selecting only an actor's id, its fragment, its worker
/// and its status.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Actor")]
pub struct PartialActorLocation {
    pub actor_id: ActorId,
    pub fragment_id: FragmentId,
    pub worker_id: WorkerId,
    pub status: ActorStatus,
}
290
/// Summary of a fragment decoded via `FromQueryResult`.
///
/// Unlike the partial models in this file it is not bound to a single entity, so the query
/// producing it has to yield columns named after these fields.
#[derive(FromQueryResult)]
pub struct FragmentDesc {
    pub fragment_id: FragmentId,
    pub job_id: ObjectId,
    /// Raw type bitmask; presumably interpreted through `FragmentTypeMask` — TODO confirm.
    pub fragment_type_mask: i32,
    pub distribution_type: DistributionType,
    pub state_table_ids: I32Array,
    /// NOTE(review): looks like a value computed by the producing query (e.g. an actor count);
    /// confirm against the caller.
    pub parallelism: i64,
    pub vnode_count: i32,
    pub stream_node: StreamNode,
}
302
303/// List all objects that are using the given one in a cascade way. It runs a recursive CTE to find all the dependencies.
304pub async fn get_referring_objects_cascade<C>(
305    obj_id: ObjectId,
306    db: &C,
307) -> MetaResult<Vec<PartialObject>>
308where
309    C: ConnectionTrait,
310{
311    let query = construct_obj_dependency_query(obj_id);
312    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
313    let objects = PartialObject::find_by_statement(Statement::from_sql_and_values(
314        db.get_database_backend(),
315        sql,
316        values,
317    ))
318    .all(db)
319    .await?;
320    Ok(objects)
321}
322
323/// Check if create a sink with given dependent objects into the target table will cause a cycle, return true if it will.
324pub async fn check_sink_into_table_cycle<C>(
325    target_table: ObjectId,
326    dependent_objs: Vec<ObjectId>,
327    db: &C,
328) -> MetaResult<bool>
329where
330    C: ConnectionTrait,
331{
332    if dependent_objs.is_empty() {
333        return Ok(false);
334    }
335
336    // special check for self referencing
337    if dependent_objs.contains(&target_table) {
338        return Ok(true);
339    }
340
341    let query = construct_sink_cycle_check_query(target_table, dependent_objs);
342    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
343
344    let res = db
345        .query_one(Statement::from_sql_and_values(
346            db.get_database_backend(),
347            sql,
348            values,
349        ))
350        .await?
351        .unwrap();
352
353    let cnt: i64 = res.try_get_by(0)?;
354
355    Ok(cnt != 0)
356}
357
358/// `ensure_object_id` ensures the existence of target object in the cluster.
359pub async fn ensure_object_id<C>(
360    object_type: ObjectType,
361    obj_id: ObjectId,
362    db: &C,
363) -> MetaResult<()>
364where
365    C: ConnectionTrait,
366{
367    let count = Object::find_by_id(obj_id).count(db).await?;
368    if count == 0 {
369        return Err(MetaError::catalog_id_not_found(
370            object_type.as_str(),
371            obj_id,
372        ));
373    }
374    Ok(())
375}
376
377pub async fn ensure_job_not_canceled<C>(job_id: ObjectId, db: &C) -> MetaResult<()>
378where
379    C: ConnectionTrait,
380{
381    let count = Object::find_by_id(job_id).count(db).await?;
382    if count == 0 {
383        return Err(MetaError::cancelled(format!(
384            "job {} might be cancelled manually or by recovery",
385            job_id
386        )));
387    }
388    Ok(())
389}
390
391/// `ensure_user_id` ensures the existence of target user in the cluster.
392pub async fn ensure_user_id<C>(user_id: UserId, db: &C) -> MetaResult<()>
393where
394    C: ConnectionTrait,
395{
396    let count = User::find_by_id(user_id).count(db).await?;
397    if count == 0 {
398        return Err(anyhow!("user {} was concurrently dropped", user_id).into());
399    }
400    Ok(())
401}
402
403/// `check_database_name_duplicate` checks whether the database name is already used in the cluster.
404pub async fn check_database_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
405where
406    C: ConnectionTrait,
407{
408    let count = Database::find()
409        .filter(database::Column::Name.eq(name))
410        .count(db)
411        .await?;
412    if count > 0 {
413        assert_eq!(count, 1);
414        return Err(MetaError::catalog_duplicated("database", name));
415    }
416    Ok(())
417}
418
419/// `check_function_signature_duplicate` checks whether the function name and its signature is already used in the target namespace.
420pub async fn check_function_signature_duplicate<C>(
421    pb_function: &PbFunction,
422    db: &C,
423) -> MetaResult<()>
424where
425    C: ConnectionTrait,
426{
427    let count = Function::find()
428        .inner_join(Object)
429        .filter(
430            object::Column::DatabaseId
431                .eq(pb_function.database_id as DatabaseId)
432                .and(object::Column::SchemaId.eq(pb_function.schema_id as SchemaId))
433                .and(function::Column::Name.eq(&pb_function.name))
434                .and(
435                    function::Column::ArgTypes
436                        .eq(DataTypeArray::from(pb_function.arg_types.clone())),
437                ),
438        )
439        .count(db)
440        .await?;
441    if count > 0 {
442        assert_eq!(count, 1);
443        return Err(MetaError::catalog_duplicated("function", &pb_function.name));
444    }
445    Ok(())
446}
447
448/// `check_connection_name_duplicate` checks whether the connection name is already used in the target namespace.
449pub async fn check_connection_name_duplicate<C>(
450    pb_connection: &PbConnection,
451    db: &C,
452) -> MetaResult<()>
453where
454    C: ConnectionTrait,
455{
456    let count = Connection::find()
457        .inner_join(Object)
458        .filter(
459            object::Column::DatabaseId
460                .eq(pb_connection.database_id as DatabaseId)
461                .and(object::Column::SchemaId.eq(pb_connection.schema_id as SchemaId))
462                .and(connection::Column::Name.eq(&pb_connection.name)),
463        )
464        .count(db)
465        .await?;
466    if count > 0 {
467        assert_eq!(count, 1);
468        return Err(MetaError::catalog_duplicated(
469            "connection",
470            &pb_connection.name,
471        ));
472    }
473    Ok(())
474}
475
476pub async fn check_secret_name_duplicate<C>(pb_secret: &PbSecret, db: &C) -> MetaResult<()>
477where
478    C: ConnectionTrait,
479{
480    let count = Secret::find()
481        .inner_join(Object)
482        .filter(
483            object::Column::DatabaseId
484                .eq(pb_secret.database_id as DatabaseId)
485                .and(object::Column::SchemaId.eq(pb_secret.schema_id as SchemaId))
486                .and(secret::Column::Name.eq(&pb_secret.name)),
487        )
488        .count(db)
489        .await?;
490    if count > 0 {
491        assert_eq!(count, 1);
492        return Err(MetaError::catalog_duplicated("secret", &pb_secret.name));
493    }
494    Ok(())
495}
496
497pub async fn check_subscription_name_duplicate<C>(
498    pb_subscription: &PbSubscription,
499    db: &C,
500) -> MetaResult<()>
501where
502    C: ConnectionTrait,
503{
504    let count = Subscription::find()
505        .inner_join(Object)
506        .filter(
507            object::Column::DatabaseId
508                .eq(pb_subscription.database_id as DatabaseId)
509                .and(object::Column::SchemaId.eq(pb_subscription.schema_id as SchemaId))
510                .and(subscription::Column::Name.eq(&pb_subscription.name)),
511        )
512        .count(db)
513        .await?;
514    if count > 0 {
515        assert_eq!(count, 1);
516        return Err(MetaError::catalog_duplicated(
517            "subscription",
518            &pb_subscription.name,
519        ));
520    }
521    Ok(())
522}
523
524/// `check_user_name_duplicate` checks whether the user is already existed in the cluster.
525pub async fn check_user_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
526where
527    C: ConnectionTrait,
528{
529    let count = User::find()
530        .filter(user::Column::Name.eq(name))
531        .count(db)
532        .await?;
533    if count > 0 {
534        assert_eq!(count, 1);
535        return Err(MetaError::catalog_duplicated("user", name));
536    }
537    Ok(())
538}
539
540/// `check_relation_name_duplicate` checks whether the relation name is already used in the target namespace.
541pub async fn check_relation_name_duplicate<C>(
542    name: &str,
543    database_id: DatabaseId,
544    schema_id: SchemaId,
545    db: &C,
546) -> MetaResult<()>
547where
548    C: ConnectionTrait,
549{
550    macro_rules! check_duplicated {
551        ($obj_type:expr, $entity:ident, $table:ident) => {
552            let object_id = Object::find()
553                .select_only()
554                .column(object::Column::Oid)
555                .inner_join($entity)
556                .filter(
557                    object::Column::DatabaseId
558                        .eq(Some(database_id))
559                        .and(object::Column::SchemaId.eq(Some(schema_id)))
560                        .and($table::Column::Name.eq(name)),
561                )
562                .into_tuple::<ObjectId>()
563                .one(db)
564                .await?;
565            if let Some(oid) = object_id {
566                let check_creation = if $obj_type == ObjectType::View {
567                    false
568                } else if $obj_type == ObjectType::Source {
569                    let source_info = Source::find_by_id(oid)
570                        .select_only()
571                        .column(source::Column::SourceInfo)
572                        .into_tuple::<Option<StreamSourceInfo>>()
573                        .one(db)
574                        .await?
575                        .unwrap();
576                    source_info.map_or(false, |info| info.to_protobuf().is_shared())
577                } else {
578                    true
579                };
580                return if check_creation
581                    && !matches!(
582                        StreamingJob::find_by_id(oid)
583                            .select_only()
584                            .column(streaming_job::Column::JobStatus)
585                            .into_tuple::<JobStatus>()
586                            .one(db)
587                            .await?,
588                        Some(JobStatus::Created)
589                    ) {
590                    Err(MetaError::catalog_under_creation(
591                        $obj_type.as_str(),
592                        name,
593                        oid,
594                    ))
595                } else {
596                    Err(MetaError::catalog_duplicated($obj_type.as_str(), name))
597                };
598            }
599        };
600    }
601    check_duplicated!(ObjectType::Table, Table, table);
602    check_duplicated!(ObjectType::Source, Source, source);
603    check_duplicated!(ObjectType::Sink, Sink, sink);
604    check_duplicated!(ObjectType::Index, Index, index);
605    check_duplicated!(ObjectType::View, View, view);
606
607    Ok(())
608}
609
610/// `check_schema_name_duplicate` checks whether the schema name is already used in the target database.
611pub async fn check_schema_name_duplicate<C>(
612    name: &str,
613    database_id: DatabaseId,
614    db: &C,
615) -> MetaResult<()>
616where
617    C: ConnectionTrait,
618{
619    let count = Object::find()
620        .inner_join(Schema)
621        .filter(
622            object::Column::ObjType
623                .eq(ObjectType::Schema)
624                .and(object::Column::DatabaseId.eq(Some(database_id)))
625                .and(schema::Column::Name.eq(name)),
626        )
627        .count(db)
628        .await?;
629    if count != 0 {
630        return Err(MetaError::catalog_duplicated("schema", name));
631    }
632
633    Ok(())
634}
635
636/// `check_object_refer_for_drop` checks whether the object is used by other objects except indexes.
637/// It returns an error that contains the details of the referring objects if it is used by others.
638pub async fn check_object_refer_for_drop<C>(
639    object_type: ObjectType,
640    object_id: ObjectId,
641    db: &C,
642) -> MetaResult<()>
643where
644    C: ConnectionTrait,
645{
646    // Ignore indexes.
647    let count = if object_type == ObjectType::Table {
648        ObjectDependency::find()
649            .join(
650                JoinType::InnerJoin,
651                object_dependency::Relation::Object1.def(),
652            )
653            .filter(
654                object_dependency::Column::Oid
655                    .eq(object_id)
656                    .and(object::Column::ObjType.ne(ObjectType::Index)),
657            )
658            .count(db)
659            .await?
660    } else {
661        ObjectDependency::find()
662            .filter(object_dependency::Column::Oid.eq(object_id))
663            .count(db)
664            .await?
665    };
666    if count != 0 {
667        // find the name of all objects that are using the given one.
668        let referring_objects = get_referring_objects(object_id, db).await?;
669        let referring_objs_map = referring_objects
670            .into_iter()
671            .filter(|o| o.obj_type != ObjectType::Index)
672            .into_group_map_by(|o| o.obj_type);
673        let mut details = vec![];
674        for (obj_type, objs) in referring_objs_map {
675            match obj_type {
676                ObjectType::Table => {
677                    let tables: Vec<(String, String)> = Object::find()
678                        .join(JoinType::InnerJoin, object::Relation::Table.def())
679                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
680                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
681                        .select_only()
682                        .column(schema::Column::Name)
683                        .column(table::Column::Name)
684                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
685                        .into_tuple()
686                        .all(db)
687                        .await?;
688                    details.extend(tables.into_iter().map(|(schema_name, table_name)| {
689                        format!(
690                            "materialized view {}.{} depends on it",
691                            schema_name, table_name
692                        )
693                    }));
694                }
695                ObjectType::Sink => {
696                    let sinks: Vec<(String, String)> = Object::find()
697                        .join(JoinType::InnerJoin, object::Relation::Sink.def())
698                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
699                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
700                        .select_only()
701                        .column(schema::Column::Name)
702                        .column(sink::Column::Name)
703                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
704                        .into_tuple()
705                        .all(db)
706                        .await?;
707                    details.extend(sinks.into_iter().map(|(schema_name, sink_name)| {
708                        format!("sink {}.{} depends on it", schema_name, sink_name)
709                    }));
710                }
711                ObjectType::View => {
712                    let views: Vec<(String, String)> = Object::find()
713                        .join(JoinType::InnerJoin, object::Relation::View.def())
714                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
715                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
716                        .select_only()
717                        .column(schema::Column::Name)
718                        .column(view::Column::Name)
719                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
720                        .into_tuple()
721                        .all(db)
722                        .await?;
723                    details.extend(views.into_iter().map(|(schema_name, view_name)| {
724                        format!("view {}.{} depends on it", schema_name, view_name)
725                    }));
726                }
727                ObjectType::Subscription => {
728                    let subscriptions: Vec<(String, String)> = Object::find()
729                        .join(JoinType::InnerJoin, object::Relation::Subscription.def())
730                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
731                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
732                        .select_only()
733                        .column(schema::Column::Name)
734                        .column(subscription::Column::Name)
735                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
736                        .into_tuple()
737                        .all(db)
738                        .await?;
739                    details.extend(subscriptions.into_iter().map(
740                        |(schema_name, subscription_name)| {
741                            format!(
742                                "subscription {}.{} depends on it",
743                                schema_name, subscription_name
744                            )
745                        },
746                    ));
747                }
748                ObjectType::Source => {
749                    let sources: Vec<(String, String)> = Object::find()
750                        .join(JoinType::InnerJoin, object::Relation::Source.def())
751                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
752                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
753                        .select_only()
754                        .column(schema::Column::Name)
755                        .column(source::Column::Name)
756                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
757                        .into_tuple()
758                        .all(db)
759                        .await?;
760                    details.extend(sources.into_iter().map(|(schema_name, view_name)| {
761                        format!("source {}.{} depends on it", schema_name, view_name)
762                    }));
763                }
764                ObjectType::Connection => {
765                    let connections: Vec<(String, String)> = Object::find()
766                        .join(JoinType::InnerJoin, object::Relation::Connection.def())
767                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
768                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
769                        .select_only()
770                        .column(schema::Column::Name)
771                        .column(connection::Column::Name)
772                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
773                        .into_tuple()
774                        .all(db)
775                        .await?;
776                    details.extend(connections.into_iter().map(|(schema_name, view_name)| {
777                        format!("connection {}.{} depends on it", schema_name, view_name)
778                    }));
779                }
780                // only the table, source, sink, subscription, view, connection and index will depend on other objects.
781                _ => bail!("unexpected referring object type: {}", obj_type.as_str()),
782            }
783        }
784
785        return Err(MetaError::permission_denied(format!(
786            "{} used by {} other objects. \nDETAIL: {}\n\
787            {}",
788            object_type.as_str(),
789            count,
790            details.join("\n"),
791            match object_type {
792                ObjectType::Function | ObjectType::Connection | ObjectType::Secret =>
793                    "HINT: DROP the dependent objects first.",
794                ObjectType::Database | ObjectType::Schema => unreachable!(),
795                _ => "HINT:  Use DROP ... CASCADE to drop the dependent objects too.",
796            }
797        )));
798    }
799    Ok(())
800}
801
802/// List all objects that are using the given one.
803pub async fn get_referring_objects<C>(object_id: ObjectId, db: &C) -> MetaResult<Vec<PartialObject>>
804where
805    C: ConnectionTrait,
806{
807    let objs = ObjectDependency::find()
808        .filter(object_dependency::Column::Oid.eq(object_id))
809        .join(
810            JoinType::InnerJoin,
811            object_dependency::Relation::Object1.def(),
812        )
813        .into_partial_model()
814        .all(db)
815        .await?;
816
817    Ok(objs)
818}
819
820/// `ensure_schema_empty` ensures that the schema is empty, used by `DROP SCHEMA`.
821pub async fn ensure_schema_empty<C>(schema_id: SchemaId, db: &C) -> MetaResult<()>
822where
823    C: ConnectionTrait,
824{
825    let count = Object::find()
826        .filter(object::Column::SchemaId.eq(Some(schema_id)))
827        .count(db)
828        .await?;
829    if count != 0 {
830        return Err(MetaError::permission_denied("schema is not empty"));
831    }
832
833    Ok(())
834}
835
836/// `list_user_info_by_ids` lists all users' info by their ids.
837pub async fn list_user_info_by_ids<C>(
838    user_ids: impl IntoIterator<Item = UserId>,
839    db: &C,
840) -> MetaResult<Vec<PbUserInfo>>
841where
842    C: ConnectionTrait,
843{
844    let mut user_infos = vec![];
845    for user_id in user_ids {
846        let user = User::find_by_id(user_id)
847            .one(db)
848            .await?
849            .ok_or_else(|| MetaError::catalog_id_not_found("user", user_id))?;
850        let mut user_info: PbUserInfo = user.into();
851        user_info.grant_privileges = get_user_privilege(user_id, db).await?;
852        user_infos.push(user_info);
853    }
854    Ok(user_infos)
855}
856
857/// `get_object_owner` returns the owner of the given object.
858pub async fn get_object_owner<C>(object_id: ObjectId, db: &C) -> MetaResult<UserId>
859where
860    C: ConnectionTrait,
861{
862    let obj_owner: UserId = Object::find_by_id(object_id)
863        .select_only()
864        .column(object::Column::OwnerId)
865        .into_tuple()
866        .one(db)
867        .await?
868        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
869    Ok(obj_owner)
870}
871
/// `construct_privilege_dependency_query` builds a recursive (`WITH RECURSIVE`) query.
/// The query returns the privileges in `ids` together with every privilege transitively
/// derived from them: rows whose `dependent_id` points, directly or through other collected
/// rows, at one of `ids`. Each result row is an (`id`, `user_id`) pair of `user_privilege`.
///
/// # Examples
///
/// ```
/// use risingwave_meta::controller::utils::construct_privilege_dependency_query;
/// use sea_orm::sea_query::*;
/// use sea_orm::*;
///
/// let query = construct_privilege_dependency_query(vec![1, 2, 3]);
///
/// assert_eq!(
///    query.to_string(MysqlQueryBuilder),
///   r#"WITH RECURSIVE `granted_privilege_ids` (`id`, `user_id`) AS (SELECT `id`, `user_id` FROM `user_privilege` WHERE `user_privilege`.`id` IN (1, 2, 3) UNION ALL (SELECT `user_privilege`.`id`, `user_privilege`.`user_id` FROM `user_privilege` INNER JOIN `granted_privilege_ids` ON `granted_privilege_ids`.`id` = `dependent_id`)) SELECT `id`, `user_id` FROM `granted_privilege_ids`"#
/// );
/// assert_eq!(
///   query.to_string(PostgresQueryBuilder),
///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL (SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id")) SELECT "id", "user_id" FROM "granted_privilege_ids""#
/// );
/// assert_eq!(
///  query.to_string(SqliteQueryBuilder),
///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id") SELECT "id", "user_id" FROM "granted_privilege_ids""#
/// );
/// ```
pub fn construct_privilege_dependency_query(ids: Vec<PrivilegeId>) -> WithQuery {
    let cte_alias = Alias::new("granted_privilege_ids");
    let cte_return_privilege_alias = Alias::new("id");
    let cte_return_user_alias = Alias::new("user_id");

    // Seed of the recursion: the requested privileges themselves.
    let mut base_query = SelectStatement::new()
        .columns([user_privilege::Column::Id, user_privilege::Column::UserId])
        .from(UserPrivilege)
        .and_where(user_privilege::Column::Id.is_in(ids))
        .to_owned();

    // Recursive step: privileges whose `dependent_id` points at a row already in the CTE.
    let cte_referencing = Query::select()
        .columns([
            (UserPrivilege, user_privilege::Column::Id),
            (UserPrivilege, user_privilege::Column::UserId),
        ])
        .from(UserPrivilege)
        .inner_join(
            cte_alias.clone(),
            Expr::col((cte_alias.clone(), cte_return_privilege_alias.clone()))
                .equals(user_privilege::Column::DependentId),
        )
        .to_owned();

    // `granted_privilege_ids (id, user_id) AS (seed UNION ALL step)`.
    let mut common_table_expr = CommonTableExpression::new();
    common_table_expr
        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
        .columns([
            cte_return_privilege_alias.clone(),
            cte_return_user_alias.clone(),
        ])
        .table_name(cte_alias.clone());

    // Final projection over the fully expanded CTE.
    SelectStatement::new()
        .columns([cte_return_privilege_alias, cte_return_user_alias])
        .from(cte_alias.clone())
        .to_owned()
        .with(
            WithClause::new()
                .recursive(true)
                .cte(common_table_expr)
                .to_owned(),
        )
}
940
941pub async fn get_internal_tables_by_id<C>(job_id: ObjectId, db: &C) -> MetaResult<Vec<TableId>>
942where
943    C: ConnectionTrait,
944{
945    let table_ids: Vec<TableId> = Table::find()
946        .select_only()
947        .column(table::Column::TableId)
948        .filter(
949            table::Column::TableType
950                .eq(TableType::Internal)
951                .and(table::Column::BelongsToJobId.eq(job_id)),
952        )
953        .into_tuple()
954        .all(db)
955        .await?;
956    Ok(table_ids)
957}
958
959pub async fn get_index_state_tables_by_table_id<C>(
960    table_id: TableId,
961    db: &C,
962) -> MetaResult<Vec<TableId>>
963where
964    C: ConnectionTrait,
965{
966    let mut index_table_ids: Vec<TableId> = Index::find()
967        .select_only()
968        .column(index::Column::IndexTableId)
969        .filter(index::Column::PrimaryTableId.eq(table_id))
970        .into_tuple()
971        .all(db)
972        .await?;
973
974    if !index_table_ids.is_empty() {
975        let internal_table_ids: Vec<TableId> = Table::find()
976            .select_only()
977            .column(table::Column::TableId)
978            .filter(
979                table::Column::TableType
980                    .eq(TableType::Internal)
981                    .and(table::Column::BelongsToJobId.is_in(index_table_ids.clone())),
982            )
983            .into_tuple()
984            .all(db)
985            .await?;
986
987        index_table_ids.extend(internal_table_ids.into_iter());
988    }
989
990    Ok(index_table_ids)
991}
992
/// Projection of a `user_privilege` row onto its id and the user holding it.
/// Rows of this shape are returned by `get_referring_privileges_cascade`.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "UserPrivilege")]
pub struct PartialUserPrivilege {
    /// Id of the privilege row.
    pub id: PrivilegeId,
    /// User the privilege is granted to.
    pub user_id: UserId,
}
999
1000pub async fn get_referring_privileges_cascade<C>(
1001    ids: Vec<PrivilegeId>,
1002    db: &C,
1003) -> MetaResult<Vec<PartialUserPrivilege>>
1004where
1005    C: ConnectionTrait,
1006{
1007    let query = construct_privilege_dependency_query(ids);
1008    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
1009    let privileges = PartialUserPrivilege::find_by_statement(Statement::from_sql_and_values(
1010        db.get_database_backend(),
1011        sql,
1012        values,
1013    ))
1014    .all(db)
1015    .await?;
1016
1017    Ok(privileges)
1018}
1019
1020/// `ensure_privileges_not_referred` ensures that the privileges are not granted to any other users.
1021pub async fn ensure_privileges_not_referred<C>(ids: Vec<PrivilegeId>, db: &C) -> MetaResult<()>
1022where
1023    C: ConnectionTrait,
1024{
1025    let count = UserPrivilege::find()
1026        .filter(user_privilege::Column::DependentId.is_in(ids))
1027        .count(db)
1028        .await?;
1029    if count != 0 {
1030        return Err(MetaError::permission_denied(format!(
1031            "privileges granted to {} other ones.",
1032            count
1033        )));
1034    }
1035    Ok(())
1036}
1037
1038/// `get_user_privilege` returns the privileges of the given user.
1039pub async fn get_user_privilege<C>(user_id: UserId, db: &C) -> MetaResult<Vec<PbGrantPrivilege>>
1040where
1041    C: ConnectionTrait,
1042{
1043    let user_privileges = UserPrivilege::find()
1044        .find_also_related(Object)
1045        .filter(user_privilege::Column::UserId.eq(user_id))
1046        .all(db)
1047        .await?;
1048    Ok(user_privileges
1049        .into_iter()
1050        .map(|(privilege, object)| {
1051            let object = object.unwrap();
1052            let oid = object.oid as _;
1053            let obj = match object.obj_type {
1054                ObjectType::Database => PbGrantObject::DatabaseId(oid),
1055                ObjectType::Schema => PbGrantObject::SchemaId(oid),
1056                ObjectType::Table | ObjectType::Index => PbGrantObject::TableId(oid),
1057                ObjectType::Source => PbGrantObject::SourceId(oid),
1058                ObjectType::Sink => PbGrantObject::SinkId(oid),
1059                ObjectType::View => PbGrantObject::ViewId(oid),
1060                ObjectType::Function => PbGrantObject::FunctionId(oid),
1061                ObjectType::Connection => PbGrantObject::ConnectionId(oid),
1062                ObjectType::Subscription => PbGrantObject::SubscriptionId(oid),
1063                ObjectType::Secret => PbGrantObject::SecretId(oid),
1064            };
1065            PbGrantPrivilege {
1066                action_with_opts: vec![PbActionWithGrantOption {
1067                    action: PbAction::from(privilege.action) as _,
1068                    with_grant_option: privilege.with_grant_option,
1069                    granted_by: privilege.granted_by as _,
1070                }],
1071                object: Some(obj),
1072            }
1073        })
1074        .collect())
1075}
1076
1077pub async fn get_table_columns(
1078    txn: &impl ConnectionTrait,
1079    id: TableId,
1080) -> MetaResult<ColumnCatalogArray> {
1081    let columns = Table::find_by_id(id)
1082        .select_only()
1083        .columns([table::Column::Columns])
1084        .into_tuple::<ColumnCatalogArray>()
1085        .one(txn)
1086        .await?
1087        .ok_or_else(|| MetaError::catalog_id_not_found("table", id))?;
1088    Ok(columns)
1089}
1090
1091/// `grant_default_privileges_automatically` grants default privileges automatically
1092/// for the given new object. It returns the list of user infos whose privileges are updated.
1093pub async fn grant_default_privileges_automatically<C>(
1094    db: &C,
1095    object_id: ObjectId,
1096) -> MetaResult<Vec<PbUserInfo>>
1097where
1098    C: ConnectionTrait,
1099{
1100    let object = Object::find_by_id(object_id)
1101        .one(db)
1102        .await?
1103        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
1104    assert_ne!(object.obj_type, ObjectType::Database);
1105
1106    let for_mview_filter = if object.obj_type == ObjectType::Table {
1107        let table_type = Table::find_by_id(object_id)
1108            .select_only()
1109            .column(table::Column::TableType)
1110            .into_tuple::<TableType>()
1111            .one(db)
1112            .await?
1113            .ok_or_else(|| MetaError::catalog_id_not_found("table", object_id))?;
1114        user_default_privilege::Column::ForMaterializedView
1115            .eq(table_type == TableType::MaterializedView)
1116    } else {
1117        user_default_privilege::Column::ForMaterializedView.eq(false)
1118    };
1119    let schema_filter = if let Some(schema_id) = &object.schema_id {
1120        user_default_privilege::Column::SchemaId.eq(*schema_id)
1121    } else {
1122        user_default_privilege::Column::SchemaId.is_null()
1123    };
1124
1125    let default_privileges: Vec<(UserId, UserId, Action, bool)> = UserDefaultPrivilege::find()
1126        .select_only()
1127        .columns([
1128            user_default_privilege::Column::Grantee,
1129            user_default_privilege::Column::GrantedBy,
1130            user_default_privilege::Column::Action,
1131            user_default_privilege::Column::WithGrantOption,
1132        ])
1133        .filter(
1134            user_default_privilege::Column::DatabaseId
1135                .eq(object.database_id.unwrap())
1136                .and(schema_filter)
1137                .and(user_default_privilege::Column::UserId.eq(object.owner_id))
1138                .and(user_default_privilege::Column::ObjectType.eq(object.obj_type))
1139                .and(for_mview_filter),
1140        )
1141        .into_tuple()
1142        .all(db)
1143        .await?;
1144    if default_privileges.is_empty() {
1145        return Ok(vec![]);
1146    }
1147
1148    let updated_user_ids = default_privileges
1149        .iter()
1150        .map(|(grantee, _, _, _)| *grantee)
1151        .collect::<HashSet<_>>();
1152
1153    let internal_table_ids = get_internal_tables_by_id(object_id, db).await?;
1154
1155    for (grantee, granted_by, action, with_grant_option) in default_privileges {
1156        UserPrivilege::insert(user_privilege::ActiveModel {
1157            user_id: Set(grantee),
1158            oid: Set(object_id),
1159            granted_by: Set(granted_by),
1160            action: Set(action),
1161            with_grant_option: Set(with_grant_option),
1162            ..Default::default()
1163        })
1164        .exec(db)
1165        .await?;
1166        if action == Action::Select && !internal_table_ids.is_empty() {
1167            // Grant SELECT privilege for internal tables if the action is SELECT.
1168            for internal_table_id in &internal_table_ids {
1169                UserPrivilege::insert(user_privilege::ActiveModel {
1170                    user_id: Set(grantee),
1171                    oid: Set(*internal_table_id as _),
1172                    granted_by: Set(granted_by),
1173                    action: Set(Action::Select),
1174                    with_grant_option: Set(with_grant_option),
1175                    ..Default::default()
1176                })
1177                .exec(db)
1178                .await?;
1179            }
1180        }
1181    }
1182
1183    let updated_user_infos = list_user_info_by_ids(updated_user_ids, db).await?;
1184    Ok(updated_user_infos)
1185}
1186
1187// todo: remove it after migrated to sql backend.
1188pub fn extract_grant_obj_id(object: &PbGrantObject) -> ObjectId {
1189    match object {
1190        PbGrantObject::DatabaseId(id)
1191        | PbGrantObject::SchemaId(id)
1192        | PbGrantObject::TableId(id)
1193        | PbGrantObject::SourceId(id)
1194        | PbGrantObject::SinkId(id)
1195        | PbGrantObject::ViewId(id)
1196        | PbGrantObject::FunctionId(id)
1197        | PbGrantObject::SubscriptionId(id)
1198        | PbGrantObject::ConnectionId(id)
1199        | PbGrantObject::SecretId(id) => *id as _,
1200    }
1201}
1202
1203pub async fn insert_fragment_relations(
1204    db: &impl ConnectionTrait,
1205    downstream_fragment_relations: &FragmentDownstreamRelation,
1206) -> MetaResult<()> {
1207    for (upstream_fragment_id, downstreams) in downstream_fragment_relations {
1208        for downstream in downstreams {
1209            let relation = fragment_relation::Model {
1210                source_fragment_id: *upstream_fragment_id as _,
1211                target_fragment_id: downstream.downstream_fragment_id as _,
1212                dispatcher_type: downstream.dispatcher_type,
1213                dist_key_indices: downstream
1214                    .dist_key_indices
1215                    .iter()
1216                    .map(|idx| *idx as i32)
1217                    .collect_vec()
1218                    .into(),
1219                output_indices: downstream
1220                    .output_mapping
1221                    .indices
1222                    .iter()
1223                    .map(|idx| *idx as i32)
1224                    .collect_vec()
1225                    .into(),
1226                output_type_mapping: Some(downstream.output_mapping.types.clone().into()),
1227            };
1228            FragmentRelation::insert(relation.into_active_model())
1229                .exec(db)
1230                .await?;
1231        }
1232    }
1233    Ok(())
1234}
1235
/// Builds the dispatchers of the source actors for every `fragment_relation` row whose source
/// fragment is in `fragment_ids`.
///
/// The result maps source fragment id -> source actor id -> that actor's dispatchers, one per
/// outgoing relation. Only actors with status `Running` are considered. The actors of every
/// fragment touched (source or target) are loaded lazily and cached, so each fragment is
/// queried at most once.
///
/// # Errors
///
/// Returns an error if a query fails, or reports "failed to find fragment" when the lookup
/// for a referenced fragment yields nothing. That lookup filters on `Running` actors, so
/// presumably a fragment without running actors is reported the same way (TODO confirm).
///
/// # Panics
///
/// Panics if a fragment lookup returns more than one fragment, or on the invariant
/// violations asserted by `compose_dispatchers`.
pub async fn get_fragment_actor_dispatchers<C>(
    db: &C,
    fragment_ids: Vec<FragmentId>,
) -> MetaResult<FragmentActorDispatchers>
where
    C: ConnectionTrait,
{
    // Distribution type of a fragment plus its running actors and their vnode bitmaps.
    type FragmentActorInfo = (
        DistributionType,
        Arc<HashMap<crate::model::ActorId, Option<Bitmap>>>,
    );
    let mut fragment_actor_cache: HashMap<FragmentId, FragmentActorInfo> = HashMap::new();
    // Loads one fragment together with its `Running` actors.
    let get_fragment_actors = |fragment_id: FragmentId| async move {
        let result: MetaResult<FragmentActorInfo> = try {
            let mut fragment_actors = Fragment::find_by_id(fragment_id)
                .find_with_related(Actor)
                .filter(actor::Column::Status.eq(ActorStatus::Running))
                .all(db)
                .await?;
            if fragment_actors.is_empty() {
                return Err(anyhow!("failed to find fragment: {}", fragment_id).into());
            }
            assert_eq!(
                fragment_actors.len(),
                1,
                "find multiple fragment {:?}",
                fragment_actors
            );
            let (fragment, actors) = fragment_actors.pop().unwrap();
            (
                fragment.distribution_type,
                Arc::new(
                    actors
                        .into_iter()
                        .map(|actor| {
                            (
                                actor.actor_id as _,
                                actor
                                    .vnode_bitmap
                                    .map(|bitmap| Bitmap::from(bitmap.to_protobuf())),
                            )
                        })
                        .collect(),
                ),
            )
        };
        result
    };
    // All relations leaving the requested fragments.
    let fragment_relations = FragmentRelation::find()
        .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids))
        .all(db)
        .await?;

    let mut actor_dispatchers_map: HashMap<_, HashMap<_, Vec<_>>> = HashMap::new();
    for fragment_relation::Model {
        source_fragment_id,
        target_fragment_id,
        dispatcher_type,
        dist_key_indices,
        output_indices,
        output_type_mapping,
    } in fragment_relations
    {
        // Fetch the source fragment's actors on first use; later relations reuse the cached
        // `Arc`, so the clone below is cheap.
        let (source_fragment_distribution, source_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(source_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(source_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // Same for the target fragment.
        let (target_fragment_distribution, target_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(target_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(target_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        let output_mapping = PbDispatchOutputMapping {
            indices: output_indices.into_u32_array(),
            types: output_type_mapping.unwrap_or_default().to_protobuf(),
        };
        let dispatchers = compose_dispatchers(
            source_fragment_distribution,
            &source_fragment_actors,
            target_fragment_id as _,
            target_fragment_distribution,
            &target_fragment_actors,
            dispatcher_type,
            dist_key_indices.into_u32_array(),
            output_mapping,
        );
        // Group the new dispatchers under their source fragment and source actor.
        let actor_dispatchers_map = actor_dispatchers_map
            .entry(source_fragment_id as _)
            .or_default();
        for (actor_id, dispatchers) in dispatchers {
            actor_dispatchers_map
                .entry(actor_id as _)
                .or_default()
                .push(dispatchers);
        }
    }
    Ok(actor_dispatchers_map)
}
1347
1348pub fn compose_dispatchers(
1349    source_fragment_distribution: DistributionType,
1350    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1351    target_fragment_id: crate::model::FragmentId,
1352    target_fragment_distribution: DistributionType,
1353    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1354    dispatcher_type: DispatcherType,
1355    dist_key_indices: Vec<u32>,
1356    output_mapping: PbDispatchOutputMapping,
1357) -> HashMap<crate::model::ActorId, PbDispatcher> {
1358    match dispatcher_type {
1359        DispatcherType::Hash => {
1360            let dispatcher = PbDispatcher {
1361                r#type: PbDispatcherType::from(dispatcher_type) as _,
1362                dist_key_indices: dist_key_indices.clone(),
1363                output_mapping: output_mapping.into(),
1364                hash_mapping: Some(
1365                    ActorMapping::from_bitmaps(
1366                        &target_fragment_actors
1367                            .iter()
1368                            .map(|(actor_id, bitmap)| {
1369                                (
1370                                    *actor_id as _,
1371                                    bitmap
1372                                        .clone()
1373                                        .expect("downstream hash dispatch must have distribution"),
1374                                )
1375                            })
1376                            .collect(),
1377                    )
1378                    .to_protobuf(),
1379                ),
1380                dispatcher_id: target_fragment_id as _,
1381                downstream_actor_id: target_fragment_actors
1382                    .keys()
1383                    .map(|actor_id| *actor_id as _)
1384                    .collect(),
1385            };
1386            source_fragment_actors
1387                .keys()
1388                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1389                .collect()
1390        }
1391        DispatcherType::Broadcast | DispatcherType::Simple => {
1392            let dispatcher = PbDispatcher {
1393                r#type: PbDispatcherType::from(dispatcher_type) as _,
1394                dist_key_indices: dist_key_indices.clone(),
1395                output_mapping: output_mapping.into(),
1396                hash_mapping: None,
1397                dispatcher_id: target_fragment_id as _,
1398                downstream_actor_id: target_fragment_actors
1399                    .keys()
1400                    .map(|actor_id| *actor_id as _)
1401                    .collect(),
1402            };
1403            source_fragment_actors
1404                .keys()
1405                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1406                .collect()
1407        }
1408        DispatcherType::NoShuffle => resolve_no_shuffle_actor_dispatcher(
1409            source_fragment_distribution,
1410            source_fragment_actors,
1411            target_fragment_distribution,
1412            target_fragment_actors,
1413        )
1414        .into_iter()
1415        .map(|(upstream_actor_id, downstream_actor_id)| {
1416            (
1417                upstream_actor_id,
1418                PbDispatcher {
1419                    r#type: PbDispatcherType::NoShuffle as _,
1420                    dist_key_indices: dist_key_indices.clone(),
1421                    output_mapping: output_mapping.clone().into(),
1422                    hash_mapping: None,
1423                    dispatcher_id: target_fragment_id as _,
1424                    downstream_actor_id: vec![downstream_actor_id as _],
1425                },
1426            )
1427        })
1428        .collect(),
1429    }
1430}
1431
1432/// return (`upstream_actor_id` -> `downstream_actor_id`)
1433pub fn resolve_no_shuffle_actor_dispatcher(
1434    source_fragment_distribution: DistributionType,
1435    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1436    target_fragment_distribution: DistributionType,
1437    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1438) -> Vec<(crate::model::ActorId, crate::model::ActorId)> {
1439    assert_eq!(source_fragment_distribution, target_fragment_distribution);
1440    assert_eq!(
1441        source_fragment_actors.len(),
1442        target_fragment_actors.len(),
1443        "no-shuffle should have equal upstream downstream actor count: {:?} {:?}",
1444        source_fragment_actors,
1445        target_fragment_actors
1446    );
1447    match source_fragment_distribution {
1448        DistributionType::Single => {
1449            let assert_singleton = |bitmap: &Option<Bitmap>| {
1450                assert!(
1451                    bitmap.as_ref().map(|bitmap| bitmap.all()).unwrap_or(true),
1452                    "not singleton: {:?}",
1453                    bitmap
1454                );
1455            };
1456            assert_eq!(
1457                source_fragment_actors.len(),
1458                1,
1459                "singleton distribution actor count not 1: {:?}",
1460                source_fragment_distribution
1461            );
1462            assert_eq!(
1463                target_fragment_actors.len(),
1464                1,
1465                "singleton distribution actor count not 1: {:?}",
1466                target_fragment_distribution
1467            );
1468            let (source_actor_id, bitmap) = source_fragment_actors.iter().next().unwrap();
1469            assert_singleton(bitmap);
1470            let (target_actor_id, bitmap) = target_fragment_actors.iter().next().unwrap();
1471            assert_singleton(bitmap);
1472            vec![(*source_actor_id, *target_actor_id)]
1473        }
1474        DistributionType::Hash => {
1475            let mut target_fragment_actor_index: HashMap<_, _> = target_fragment_actors
1476                .iter()
1477                .map(|(actor_id, bitmap)| {
1478                    let bitmap = bitmap
1479                        .as_ref()
1480                        .expect("hash distribution should have bitmap");
1481                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1482                    (first_vnode, (*actor_id, bitmap))
1483                })
1484                .collect();
1485            source_fragment_actors
1486                .iter()
1487                .map(|(source_actor_id, bitmap)| {
1488                    let bitmap = bitmap
1489                        .as_ref()
1490                        .expect("hash distribution should have bitmap");
1491                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1492                    let (target_actor_id, target_bitmap) =
1493                        target_fragment_actor_index.remove(&first_vnode).unwrap_or_else(|| {
1494                            panic!(
1495                                "cannot find matched target actor: {} {:?} {:?} {:?}",
1496                                source_actor_id,
1497                                first_vnode,
1498                                source_fragment_actors,
1499                                target_fragment_actors
1500                            );
1501                        });
1502                    assert_eq!(
1503                        bitmap,
1504                        target_bitmap,
1505                        "cannot find matched target actor due to bitmap mismatch: {} {:?} {:?} {:?}",
1506                        source_actor_id,
1507                        first_vnode,
1508                        source_fragment_actors,
1509                        target_fragment_actors
1510                    );
1511                    (*source_actor_id, target_actor_id)
1512                }).collect()
1513        }
1514    }
1515}
1516
1517/// `get_fragment_mappings` returns the fragment vnode mappings of the given job.
1518pub async fn get_fragment_mappings<C>(
1519    db: &C,
1520    job_id: ObjectId,
1521) -> MetaResult<Vec<PbFragmentWorkerSlotMapping>>
1522where
1523    C: ConnectionTrait,
1524{
1525    let job_actors: Vec<(
1526        FragmentId,
1527        DistributionType,
1528        ActorId,
1529        Option<VnodeBitmap>,
1530        WorkerId,
1531        ActorStatus,
1532    )> = Actor::find()
1533        .select_only()
1534        .columns([
1535            fragment::Column::FragmentId,
1536            fragment::Column::DistributionType,
1537        ])
1538        .columns([
1539            actor::Column::ActorId,
1540            actor::Column::VnodeBitmap,
1541            actor::Column::WorkerId,
1542            actor::Column::Status,
1543        ])
1544        .join(JoinType::InnerJoin, actor::Relation::Fragment.def())
1545        .filter(fragment::Column::JobId.eq(job_id))
1546        .into_tuple()
1547        .all(db)
1548        .await?;
1549
1550    Ok(rebuild_fragment_mapping_from_actors(job_actors))
1551}
1552
1553pub fn rebuild_fragment_mapping(fragment: &SharedFragmentInfo) -> PbFragmentWorkerSlotMapping {
1554    let fragment_worker_slot_mapping = match fragment.distribution_type {
1555        DistributionType::Single => {
1556            let actor = fragment.actors.values().exactly_one().unwrap();
1557            WorkerSlotMapping::new_single(WorkerSlotId::new(actor.worker_id as _, 0))
1558        }
1559        DistributionType::Hash => {
1560            let actor_bitmaps: HashMap<_, _> = fragment
1561                .actors
1562                .iter()
1563                .map(|(actor_id, actor_info)| {
1564                    let vnode_bitmap = actor_info
1565                        .vnode_bitmap
1566                        .as_ref()
1567                        .cloned()
1568                        .expect("actor bitmap shouldn't be none in hash fragment");
1569
1570                    (*actor_id as hash::ActorId, vnode_bitmap)
1571                })
1572                .collect();
1573
1574            let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
1575
1576            let actor_locations = fragment
1577                .actors
1578                .iter()
1579                .map(|(actor_id, actor_info)| {
1580                    (*actor_id as hash::ActorId, actor_info.worker_id as u32)
1581                })
1582                .collect();
1583
1584            actor_mapping.to_worker_slot(&actor_locations)
1585        }
1586    };
1587
1588    PbFragmentWorkerSlotMapping {
1589        fragment_id: fragment.fragment_id,
1590        mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1591    }
1592}
1593
1594pub fn rebuild_fragment_mapping_from_actors(
1595    job_actors: Vec<(
1596        FragmentId,
1597        DistributionType,
1598        ActorId,
1599        Option<VnodeBitmap>,
1600        WorkerId,
1601        ActorStatus,
1602    )>,
1603) -> Vec<FragmentWorkerSlotMapping> {
1604    let mut all_actor_locations = HashMap::new();
1605    let mut actor_bitmaps = HashMap::new();
1606    let mut fragment_actors = HashMap::new();
1607    let mut fragment_dist = HashMap::new();
1608
1609    for (fragment_id, dist, actor_id, bitmap, worker_id, actor_status) in job_actors {
1610        if actor_status == ActorStatus::Inactive {
1611            continue;
1612        }
1613
1614        all_actor_locations
1615            .entry(fragment_id)
1616            .or_insert(HashMap::new())
1617            .insert(actor_id as hash::ActorId, worker_id as u32);
1618        actor_bitmaps.insert(actor_id, bitmap);
1619        fragment_actors
1620            .entry(fragment_id)
1621            .or_insert_with(Vec::new)
1622            .push(actor_id);
1623        fragment_dist.insert(fragment_id, dist);
1624    }
1625
1626    let mut result = vec![];
1627    for (fragment_id, dist) in fragment_dist {
1628        let mut actor_locations = all_actor_locations.remove(&fragment_id).unwrap();
1629        let fragment_worker_slot_mapping = match dist {
1630            DistributionType::Single => {
1631                let actor = fragment_actors
1632                    .remove(&fragment_id)
1633                    .unwrap()
1634                    .into_iter()
1635                    .exactly_one()
1636                    .unwrap() as hash::ActorId;
1637                let actor_location = actor_locations.remove(&actor).unwrap();
1638
1639                WorkerSlotMapping::new_single(WorkerSlotId::new(actor_location, 0))
1640            }
1641            DistributionType::Hash => {
1642                let actors = fragment_actors.remove(&fragment_id).unwrap();
1643
1644                let all_actor_bitmaps: HashMap<_, _> = actors
1645                    .iter()
1646                    .map(|actor_id| {
1647                        let vnode_bitmap = actor_bitmaps
1648                            .remove(actor_id)
1649                            .flatten()
1650                            .expect("actor bitmap shouldn't be none in hash fragment");
1651
1652                        let bitmap = Bitmap::from(&vnode_bitmap.to_protobuf());
1653                        (*actor_id as hash::ActorId, bitmap)
1654                    })
1655                    .collect();
1656
1657                let actor_mapping = ActorMapping::from_bitmaps(&all_actor_bitmaps);
1658
1659                actor_mapping.to_worker_slot(&actor_locations)
1660            }
1661        };
1662
1663        result.push(PbFragmentWorkerSlotMapping {
1664            fragment_id: fragment_id as u32,
1665            mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1666        })
1667    }
1668    result
1669}
1670
1671/// `get_fragment_actor_ids` returns the fragment actor ids of the given fragments.
1672pub async fn get_fragment_actor_ids<C>(
1673    db: &C,
1674    fragment_ids: Vec<FragmentId>,
1675) -> MetaResult<HashMap<FragmentId, Vec<ActorId>>>
1676where
1677    C: ConnectionTrait,
1678{
1679    let fragment_actors: Vec<(FragmentId, ActorId)> = Actor::find()
1680        .select_only()
1681        .columns([actor::Column::FragmentId, actor::Column::ActorId])
1682        .filter(actor::Column::FragmentId.is_in(fragment_ids))
1683        .into_tuple()
1684        .all(db)
1685        .await?;
1686
1687    Ok(fragment_actors.into_iter().into_group_map())
1688}
1689
1690/// For the given streaming jobs, returns
1691/// - All source fragments
1692/// - All sink fragments
1693/// - All actors
1694/// - All fragments
1695pub async fn get_fragments_for_jobs<C>(
1696    db: &C,
1697    streaming_jobs: Vec<ObjectId>,
1698) -> MetaResult<(
1699    HashMap<SourceId, BTreeSet<FragmentId>>,
1700    HashSet<FragmentId>,
1701    HashSet<ActorId>,
1702    HashSet<FragmentId>,
1703)>
1704where
1705    C: ConnectionTrait,
1706{
1707    if streaming_jobs.is_empty() {
1708        return Ok((
1709            HashMap::default(),
1710            HashSet::default(),
1711            HashSet::default(),
1712            HashSet::default(),
1713        ));
1714    }
1715
1716    let fragments: Vec<(FragmentId, i32, StreamNode)> = Fragment::find()
1717        .select_only()
1718        .columns([
1719            fragment::Column::FragmentId,
1720            fragment::Column::FragmentTypeMask,
1721            fragment::Column::StreamNode,
1722        ])
1723        .filter(fragment::Column::JobId.is_in(streaming_jobs))
1724        .into_tuple()
1725        .all(db)
1726        .await?;
1727    let actors: Vec<ActorId> = Actor::find()
1728        .select_only()
1729        .column(actor::Column::ActorId)
1730        .filter(
1731            actor::Column::FragmentId.is_in(fragments.iter().map(|(id, _, _)| *id).collect_vec()),
1732        )
1733        .into_tuple()
1734        .all(db)
1735        .await?;
1736
1737    let fragment_ids = fragments
1738        .iter()
1739        .map(|(fragment_id, _, _)| *fragment_id)
1740        .collect();
1741
1742    let mut source_fragment_ids: HashMap<SourceId, BTreeSet<FragmentId>> = HashMap::new();
1743    let mut sink_fragment_ids: HashSet<FragmentId> = HashSet::new();
1744    for (fragment_id, mask, stream_node) in fragments {
1745        if FragmentTypeMask::from(mask).contains(FragmentTypeFlag::Source)
1746            && let Some(source_id) = stream_node.to_protobuf().find_stream_source()
1747        {
1748            source_fragment_ids
1749                .entry(source_id as _)
1750                .or_default()
1751                .insert(fragment_id);
1752        }
1753        if FragmentTypeMask::from(mask).contains(FragmentTypeFlag::Sink) {
1754            sink_fragment_ids.insert(fragment_id);
1755        }
1756    }
1757
1758    Ok((
1759        source_fragment_ids,
1760        sink_fragment_ids,
1761        actors.into_iter().collect(),
1762        fragment_ids,
1763    ))
1764}
1765
1766/// Build a object group for notifying the deletion of the given objects.
1767///
1768/// Note that only id fields are filled in the object info, as the arguments are partial objects.
1769/// As a result, the returned notification info should only be used for deletion.
1770pub(crate) fn build_object_group_for_delete(
1771    partial_objects: Vec<PartialObject>,
1772) -> NotificationInfo {
1773    let mut objects = vec![];
1774    for obj in partial_objects {
1775        match obj.obj_type {
1776            ObjectType::Database => objects.push(PbObject {
1777                object_info: Some(PbObjectInfo::Database(PbDatabase {
1778                    id: obj.oid as _,
1779                    ..Default::default()
1780                })),
1781            }),
1782            ObjectType::Schema => objects.push(PbObject {
1783                object_info: Some(PbObjectInfo::Schema(PbSchema {
1784                    id: obj.oid as _,
1785                    database_id: obj.database_id.unwrap() as _,
1786                    ..Default::default()
1787                })),
1788            }),
1789            ObjectType::Table => objects.push(PbObject {
1790                object_info: Some(PbObjectInfo::Table(PbTable {
1791                    id: obj.oid as _,
1792                    schema_id: obj.schema_id.unwrap() as _,
1793                    database_id: obj.database_id.unwrap() as _,
1794                    ..Default::default()
1795                })),
1796            }),
1797            ObjectType::Source => objects.push(PbObject {
1798                object_info: Some(PbObjectInfo::Source(PbSource {
1799                    id: obj.oid as _,
1800                    schema_id: obj.schema_id.unwrap() as _,
1801                    database_id: obj.database_id.unwrap() as _,
1802                    ..Default::default()
1803                })),
1804            }),
1805            ObjectType::Sink => objects.push(PbObject {
1806                object_info: Some(PbObjectInfo::Sink(PbSink {
1807                    id: obj.oid as _,
1808                    schema_id: obj.schema_id.unwrap() as _,
1809                    database_id: obj.database_id.unwrap() as _,
1810                    ..Default::default()
1811                })),
1812            }),
1813            ObjectType::Subscription => objects.push(PbObject {
1814                object_info: Some(PbObjectInfo::Subscription(PbSubscription {
1815                    id: obj.oid as _,
1816                    schema_id: obj.schema_id.unwrap() as _,
1817                    database_id: obj.database_id.unwrap() as _,
1818                    ..Default::default()
1819                })),
1820            }),
1821            ObjectType::View => objects.push(PbObject {
1822                object_info: Some(PbObjectInfo::View(PbView {
1823                    id: obj.oid as _,
1824                    schema_id: obj.schema_id.unwrap() as _,
1825                    database_id: obj.database_id.unwrap() as _,
1826                    ..Default::default()
1827                })),
1828            }),
1829            ObjectType::Index => {
1830                objects.push(PbObject {
1831                    object_info: Some(PbObjectInfo::Index(PbIndex {
1832                        id: obj.oid as _,
1833                        schema_id: obj.schema_id.unwrap() as _,
1834                        database_id: obj.database_id.unwrap() as _,
1835                        ..Default::default()
1836                    })),
1837                });
1838                objects.push(PbObject {
1839                    object_info: Some(PbObjectInfo::Table(PbTable {
1840                        id: obj.oid as _,
1841                        schema_id: obj.schema_id.unwrap() as _,
1842                        database_id: obj.database_id.unwrap() as _,
1843                        ..Default::default()
1844                    })),
1845                });
1846            }
1847            ObjectType::Function => objects.push(PbObject {
1848                object_info: Some(PbObjectInfo::Function(PbFunction {
1849                    id: obj.oid as _,
1850                    schema_id: obj.schema_id.unwrap() as _,
1851                    database_id: obj.database_id.unwrap() as _,
1852                    ..Default::default()
1853                })),
1854            }),
1855            ObjectType::Connection => objects.push(PbObject {
1856                object_info: Some(PbObjectInfo::Connection(PbConnection {
1857                    id: obj.oid as _,
1858                    schema_id: obj.schema_id.unwrap() as _,
1859                    database_id: obj.database_id.unwrap() as _,
1860                    ..Default::default()
1861                })),
1862            }),
1863            ObjectType::Secret => objects.push(PbObject {
1864                object_info: Some(PbObjectInfo::Secret(PbSecret {
1865                    id: obj.oid as _,
1866                    schema_id: obj.schema_id.unwrap() as _,
1867                    database_id: obj.database_id.unwrap() as _,
1868                    ..Default::default()
1869                })),
1870            }),
1871        }
1872    }
1873    NotificationInfo::ObjectGroup(PbObjectGroup { objects })
1874}
1875
1876pub fn extract_external_table_name_from_definition(table_definition: &str) -> Option<String> {
1877    let [mut definition]: [_; 1] = Parser::parse_sql(table_definition)
1878        .context("unable to parse table definition")
1879        .inspect_err(|e| {
1880            tracing::error!(
1881                target: "auto_schema_change",
1882                error = %e.as_report(),
1883                "failed to parse table definition")
1884        })
1885        .unwrap()
1886        .try_into()
1887        .unwrap();
1888    if let SqlStatement::CreateTable { cdc_table_info, .. } = &mut definition {
1889        cdc_table_info
1890            .clone()
1891            .map(|cdc_table_info| cdc_table_info.external_table_name)
1892    } else {
1893        None
1894    }
1895}
1896
/// `rename_relation` renames the target relation and rewrites its definition accordingly.
///
/// The updates are issued within `txn`; this function does not commit it. It returns the updated
/// relations (for notification) and the relation's old name.
///
/// - `Table`: the source associated with the table (if any) is renamed to the same new name too.
/// - `Index`: both the index row and its index table are renamed. The returned old name is the
///   old name of the index table.
/// - Views keep their definition untouched; only the name changes.
///
/// # Errors
/// Returns an error if any database operation fails.
///
/// # Panics
/// Panics if the relation, or its related object row, does not exist. Also panics for object
/// types other than table, source, sink, subscription, view and index.
pub async fn rename_relation(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
) -> MetaResult<(Vec<PbObject>, String)> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename;

    let mut to_update_relations = vec![];
    // rename relation.
    // Loads `$entity` by id together with its object row, sets the new name (and, except for
    // views, rewrites the definition), persists only `$identity`/name/definition, records the
    // updated catalog object, and evaluates to the old name.
    macro_rules! rename_relation {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            let obj = obj.unwrap();
            let old_name = relation.name.clone();
            relation.name = object_name.into();
            if obj.obj_type != ObjectType::View {
                relation.definition = alter_relation_rename(&relation.definition, object_name);
            }
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                name: Set(object_name.into()),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(ObjectModel(relation, obj).into())),
            });
            old_name
        }};
    }
    // TODO: check is there any thing to change for shared source?
    let old_name = match object_type {
        ObjectType::Table => {
            // Rename the associated source (if any) first; its old name is discarded.
            let associated_source_id: Option<SourceId> = Source::find()
                .select_only()
                .column(source::Column::SourceId)
                .filter(source::Column::OptionalAssociatedTableId.eq(object_id))
                .into_tuple()
                .one(txn)
                .await?;
            if let Some(source_id) = associated_source_id {
                rename_relation!(Source, source, source_id, source_id);
            }
            rename_relation!(Table, table, table_id, object_id)
        }
        ObjectType::Source => rename_relation!(Source, source, source_id, object_id),
        ObjectType::Sink => rename_relation!(Sink, sink, sink_id, object_id),
        ObjectType::Subscription => {
            rename_relation!(Subscription, subscription, subscription_id, object_id)
        }
        ObjectType::View => rename_relation!(View, view, view_id, object_id),
        ObjectType::Index => {
            let (mut index, obj) = Index::find_by_id(object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            index.name = object_name.into();
            let index_table_id = index.index_table_id;
            // The index table is renamed through the macro; its old name is what we return.
            let old_name = rename_relation!(Table, table, table_id, index_table_id);

            // the name of index and its associated table is the same.
            let active_model = index::ActiveModel {
                index_id: sea_orm::ActiveValue::Set(index.index_id),
                name: sea_orm::ActiveValue::Set(object_name.into()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::Index(ObjectModel(index, obj.unwrap()).into())),
            });
            old_name
        }
        _ => unreachable!("only relation name can be altered."),
    };

    Ok((to_update_relations, old_name))
}
1985
1986pub async fn get_database_resource_group<C>(txn: &C, database_id: ObjectId) -> MetaResult<String>
1987where
1988    C: ConnectionTrait,
1989{
1990    let database_resource_group: Option<String> = Database::find_by_id(database_id)
1991        .select_only()
1992        .column(database::Column::ResourceGroup)
1993        .into_tuple()
1994        .one(txn)
1995        .await?
1996        .ok_or_else(|| MetaError::catalog_id_not_found("database", database_id))?;
1997
1998    Ok(database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned()))
1999}
2000
2001pub async fn get_existing_job_resource_group<C>(
2002    txn: &C,
2003    streaming_job_id: ObjectId,
2004) -> MetaResult<String>
2005where
2006    C: ConnectionTrait,
2007{
2008    let (job_specific_resource_group, database_resource_group): (Option<String>, Option<String>) =
2009        StreamingJob::find_by_id(streaming_job_id)
2010            .select_only()
2011            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
2012            .join(JoinType::InnerJoin, object::Relation::Database2.def())
2013            .column(streaming_job::Column::SpecificResourceGroup)
2014            .column(database::Column::ResourceGroup)
2015            .into_tuple()
2016            .one(txn)
2017            .await?
2018            .ok_or_else(|| MetaError::catalog_id_not_found("streaming job", streaming_job_id))?;
2019
2020    Ok(job_specific_resource_group.unwrap_or_else(|| {
2021        database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned())
2022    }))
2023}
2024
2025pub fn filter_workers_by_resource_group(
2026    workers: &HashMap<u32, WorkerNode>,
2027    resource_group: &str,
2028) -> BTreeSet<WorkerId> {
2029    workers
2030        .iter()
2031        .filter(|&(_, worker)| {
2032            worker
2033                .resource_group()
2034                .map(|node_label| node_label.as_str() == resource_group)
2035                .unwrap_or(false)
2036        })
2037        .map(|(id, _)| *id as WorkerId)
2038        .collect()
2039}
2040
/// `rename_relation_refer` rewrites the definitions of the relations that refer to a renamed
/// object, replacing references to `old_name` with `object_name`.
///
/// The referring objects come from `get_referring_objects`. When the renamed object is a table,
/// the sinks whose `target_table` is that table are included as well. All updates are issued
/// within `txn` (this function does not commit it). The updated relations are returned so they
/// can be notified.
///
/// # Errors
/// Returns an error if a database operation fails, or if a referring object is of a type other
/// than table, sink, subscription, view or index.
///
/// # Panics
/// Panics if a referring relation, its object row, or an index's index-table row is missing.
pub async fn rename_relation_refer(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
    old_name: &str,
) -> MetaResult<Vec<PbObject>> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename_refs;

    let mut to_update_relations = vec![];
    // Loads `$entity` by id with its object row, rewrites references in its definition,
    // persists only `$identity`/definition, and records the updated catalog object.
    macro_rules! rename_relation_ref {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            relation.definition =
                alter_relation_rename_refs(&relation.definition, old_name, object_name);
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(
                    ObjectModel(relation, obj.unwrap()).into(),
                )),
            });
        }};
    }
    let mut objs = get_referring_objects(object_id, txn).await?;
    if object_type == ObjectType::Table {
        // Sinks writing into this table mention it in their definition, so they are
        // rewritten too. Only `oid` and `obj_type` are used below, hence the `None` ids.
        let incoming_sinks: Vec<SinkId> = Sink::find()
            .select_only()
            .column(sink::Column::SinkId)
            .filter(sink::Column::TargetTable.eq(object_id))
            .into_tuple()
            .all(txn)
            .await?;

        objs.extend(incoming_sinks.into_iter().map(|id| PartialObject {
            oid: id,
            obj_type: ObjectType::Sink,
            schema_id: None,
            database_id: None,
        }));
    }

    for obj in objs {
        match obj.obj_type {
            ObjectType::Table => rename_relation_ref!(Table, table, table_id, obj.oid),
            ObjectType::Sink => rename_relation_ref!(Sink, sink, sink_id, obj.oid),
            ObjectType::Subscription => {
                rename_relation_ref!(Subscription, subscription, subscription_id, obj.oid)
            }
            ObjectType::View => rename_relation_ref!(View, view, view_id, obj.oid),
            ObjectType::Index => {
                // The definition lives on the index table, so rewrite that table row.
                let index_table_id: Option<TableId> = Index::find_by_id(obj.oid)
                    .select_only()
                    .column(index::Column::IndexTableId)
                    .into_tuple()
                    .one(txn)
                    .await?;
                rename_relation_ref!(Table, table, table_id, index_table_id.unwrap());
            }
            _ => {
                bail!(
                    "only the table, sink, subscription, view and index will depend on other objects."
                )
            }
        }
    }

    Ok(to_update_relations)
}
2122
/// Validates that a subscription can be safely deleted.
///
/// Deletion is allowed when either of the following holds:
/// 1. After deleting the subscription, the upstream table still has at least one subscription.
/// 2. The upstream table is not referred to by any object (other than this subscription) that
///    lives in a different database than the table.
///
/// # Errors
/// - "catalog id not found" if the subscription does not exist;
/// - "permission denied" if condition 2 is violated while this is the table's last
///   subscription;
/// - any database error.
pub async fn validate_subscription_deletion<C>(txn: &C, subscription_id: ObjectId) -> MetaResult<()>
where
    C: ConnectionTrait,
{
    let upstream_table_id: ObjectId = Subscription::find_by_id(subscription_id)
        .select_only()
        .column(subscription::Column::DependentTableId)
        .into_tuple()
        .one(txn)
        .await?
        .ok_or_else(|| MetaError::catalog_id_not_found("subscription", subscription_id))?;

    let cnt = Subscription::find()
        .filter(subscription::Column::DependentTableId.eq(upstream_table_id))
        .count(txn)
        .await?;
    if cnt > 1 {
        // At least one subscription remains on the upstream table
        // once this subscription is dropped.
        return Ok(());
    }

    // Ensure that the upstream table is not referred to by any cross-db object.
    // NOTE(review): presumably `Object2` joins the dependency's `oid` side (aliased `o1`) and
    // `Object1` joins its `used_by` side (aliased `o2`); confirm against the relation defs.
    let obj_alias = Alias::new("o1");
    let used_by_alias = Alias::new("o2");
    let count = ObjectDependency::find()
        .join_as(
            JoinType::InnerJoin,
            object_dependency::Relation::Object2.def(),
            obj_alias.clone(),
        )
        .join_as(
            JoinType::InnerJoin,
            object_dependency::Relation::Object1.def(),
            used_by_alias.clone(),
        )
        .filter(
            object_dependency::Column::Oid
                .eq(upstream_table_id)
                .and(object_dependency::Column::UsedBy.ne(subscription_id))
                .and(
                    Expr::col((obj_alias, object::Column::DatabaseId))
                        .ne(Expr::col((used_by_alias, object::Column::DatabaseId))),
                ),
        )
        .count(txn)
        .await?;

    if count != 0 {
        return Err(MetaError::permission_denied(format!(
            "Referenced by {} cross-db objects.",
            count
        )));
    }

    Ok(())
}
2183
2184pub async fn fetch_target_fragments<C>(
2185    txn: &C,
2186    src_fragment_id: impl IntoIterator<Item = FragmentId>,
2187) -> MetaResult<HashMap<FragmentId, Vec<FragmentId>>>
2188where
2189    C: ConnectionTrait,
2190{
2191    let source_target_fragments: Vec<(FragmentId, FragmentId)> = FragmentRelation::find()
2192        .select_only()
2193        .columns([
2194            fragment_relation::Column::SourceFragmentId,
2195            fragment_relation::Column::TargetFragmentId,
2196        ])
2197        .filter(fragment_relation::Column::SourceFragmentId.is_in(src_fragment_id))
2198        .into_tuple()
2199        .all(txn)
2200        .await?;
2201
2202    let source_target_fragments = source_target_fragments.into_iter().into_group_map();
2203
2204    Ok(source_target_fragments)
2205}
2206
2207pub async fn get_sink_fragment_by_ids<C>(
2208    txn: &C,
2209    sink_ids: Vec<SinkId>,
2210) -> MetaResult<HashMap<SinkId, FragmentId>>
2211where
2212    C: ConnectionTrait,
2213{
2214    let sink_num = sink_ids.len();
2215    let sink_fragment_ids: Vec<(SinkId, FragmentId)> = Fragment::find()
2216        .select_only()
2217        .columns([fragment::Column::JobId, fragment::Column::FragmentId])
2218        .filter(
2219            fragment::Column::JobId
2220                .is_in(sink_ids)
2221                .and(FragmentTypeMask::intersects(FragmentTypeFlag::Sink)),
2222        )
2223        .into_tuple()
2224        .all(txn)
2225        .await?;
2226
2227    if sink_fragment_ids.len() != sink_num {
2228        return Err(anyhow::anyhow!(
2229            "expected exactly one sink fragment for each sink, but got {} fragments for {} sinks",
2230            sink_fragment_ids.len(),
2231            sink_num
2232        )
2233        .into());
2234    }
2235
2236    Ok(sink_fragment_ids.into_iter().collect())
2237}
2238
2239pub async fn has_table_been_migrated<C>(txn: &C, table_id: TableId) -> MetaResult<bool>
2240where
2241    C: ConnectionTrait,
2242{
2243    let mview_fragment: Vec<i32> = Fragment::find()
2244        .select_only()
2245        .column(fragment::Column::FragmentTypeMask)
2246        .filter(
2247            fragment::Column::JobId
2248                .eq(table_id)
2249                .and(FragmentTypeMask::intersects(FragmentTypeFlag::Mview)),
2250        )
2251        .into_tuple()
2252        .all(txn)
2253        .await?;
2254
2255    let mview_fragment_len = mview_fragment.len();
2256    if mview_fragment_len != 1 {
2257        bail!(
2258            "expected exactly one mview fragment for table {}, found {}",
2259            table_id,
2260            mview_fragment_len
2261        );
2262    }
2263
2264    let mview_fragment = mview_fragment.into_iter().next().unwrap();
2265    let migrated =
2266        FragmentTypeMask::from(mview_fragment).contains(FragmentTypeFlag::UpstreamSinkUnion);
2267
2268    Ok(migrated)
2269}
2270
2271pub fn build_select_node_list(
2272    from: &[ColumnCatalog],
2273    to: &[ColumnCatalog],
2274) -> MetaResult<Vec<PbExprNode>> {
2275    let mut exprs = Vec::with_capacity(to.len());
2276    let idx_by_col_id = from
2277        .iter()
2278        .enumerate()
2279        .map(|(idx, col)| (col.column_desc.as_ref().unwrap().column_id, idx))
2280        .collect::<HashMap<_, _>>();
2281
2282    for to_col in to {
2283        let to_col = to_col.column_desc.as_ref().unwrap();
2284        let to_col_type = to_col.column_type.clone();
2285        if let Some(from_idx) = idx_by_col_id.get(&to_col.column_id) {
2286            let from_col_type = from[*from_idx]
2287                .column_desc
2288                .as_ref()
2289                .unwrap()
2290                .column_type
2291                .clone();
2292            if to_col_type != from_col_type {
2293                return Err(anyhow!(
2294                    "Column type mismatch: {:?} != {:?}",
2295                    from_col_type,
2296                    to_col_type
2297                )
2298                .into());
2299            }
2300            exprs.push(PbExprNode {
2301                function_type: expr_node::Type::Unspecified.into(),
2302                return_type: to_col_type,
2303                rex_node: Some(expr_node::RexNode::InputRef(*from_idx as _)),
2304            });
2305        } else {
2306            let to_default_node =
2307                if let Some(GeneratedOrDefaultColumn::DefaultColumn(DefaultColumnDesc {
2308                    expr,
2309                    ..
2310                })) = &to_col.generated_or_default_column
2311                {
2312                    expr.clone().unwrap()
2313                } else {
2314                    let null = Datum::None.to_protobuf();
2315                    PbExprNode {
2316                        function_type: expr_node::Type::Unspecified.into(),
2317                        return_type: to_col_type,
2318                        rex_node: Some(expr_node::RexNode::Constant(null)),
2319                    }
2320                };
2321            exprs.push(to_default_node);
2322        }
2323    }
2324
2325    Ok(exprs)
2326}
2327
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_cdc_table_name() {
        let ddl1 = "CREATE TABLE t1 () FROM pg_source TABLE 'public.t1'";
        let ddl2 = "CREATE TABLE t2 (v1 int) FROM pg_source TABLE 'mydb.t2'";
        assert_eq!(
            extract_external_table_name_from_definition(ddl1),
            Some("public.t1".into())
        );
        assert_eq!(
            extract_external_table_name_from_definition(ddl2),
            Some("mydb.t2".into())
        );
    }

    #[test]
    fn test_extract_cdc_table_name_without_cdc_info() {
        // A plain table has no CDC table info, hence no external table name.
        assert_eq!(
            extract_external_table_name_from_definition("CREATE TABLE t3 (v1 int)"),
            None
        );
        // Statements other than CREATE TABLE never yield a name.
        assert_eq!(
            extract_external_table_name_from_definition("CREATE VIEW v AS SELECT 1"),
            None
        );
    }
}