risingwave_meta/controller/
utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::{BTreeSet, HashMap, HashSet};
17use std::sync::Arc;
18
19use anyhow::{Context, anyhow};
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
23use risingwave_common::hash::{ActorMapping, VnodeBitmapExt, WorkerSlotId, WorkerSlotMapping};
24use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
25use risingwave_common::{bail, hash};
26use risingwave_meta_model::actor::ActorStatus;
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::object::ObjectType;
29use risingwave_meta_model::prelude::*;
30use risingwave_meta_model::table::TableType;
31use risingwave_meta_model::user_privilege::Action;
32use risingwave_meta_model::{
33    ActorId, ColumnCatalogArray, DataTypeArray, DatabaseId, DispatcherType, FragmentId, I32Array,
34    JobStatus, ObjectId, PrivilegeId, SchemaId, SourceId, StreamNode, StreamSourceInfo, TableId,
35    UserId, VnodeBitmap, WorkerId, actor, connection, database, fragment, fragment_relation,
36    function, index, object, object_dependency, schema, secret, sink, source, streaming_job,
37    subscription, table, user, user_default_privilege, user_privilege, view,
38};
39use risingwave_meta_model_migration::WithQuery;
40use risingwave_pb::catalog::{
41    PbConnection, PbDatabase, PbFunction, PbIndex, PbSchema, PbSecret, PbSink, PbSource,
42    PbSubscription, PbTable, PbView,
43};
44use risingwave_pb::common::WorkerNode;
45use risingwave_pb::meta::object::PbObjectInfo;
46use risingwave_pb::meta::subscribe_response::Info as NotificationInfo;
47use risingwave_pb::meta::{
48    FragmentWorkerSlotMapping, PbFragmentWorkerSlotMapping, PbObject, PbObjectGroup,
49};
50use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher, PbDispatcherType};
51use risingwave_pb::user::grant_privilege::{PbActionWithGrantOption, PbObject as PbGrantObject};
52use risingwave_pb::user::{PbAction, PbGrantPrivilege, PbUserInfo};
53use risingwave_sqlparser::ast::Statement as SqlStatement;
54use risingwave_sqlparser::parser::Parser;
55use sea_orm::sea_query::{
56    Alias, CommonTableExpression, Expr, Query, QueryStatementBuilder, SelectStatement, UnionType,
57    WithClause,
58};
59use sea_orm::{
60    ColumnTrait, ConnectionTrait, DatabaseTransaction, DerivePartialModel, EntityTrait,
61    FromQueryResult, IntoActiveModel, JoinType, Order, PaginatorTrait, QueryFilter, QuerySelect,
62    RelationTrait, Set, Statement,
63};
64use thiserror_ext::AsReport;
65
66use crate::barrier::SharedFragmentInfo;
67use crate::controller::ObjectModel;
68use crate::model::{FragmentActorDispatchers, FragmentDownstreamRelation};
69use crate::{MetaError, MetaResult};
70
71/// This function will construct a query using recursive cte to find all objects[(id, `obj_type`)] that are used by the given object.
72///
73/// # Examples
74///
75/// ```
76/// use risingwave_meta::controller::utils::construct_obj_dependency_query;
77/// use sea_orm::sea_query::*;
78/// use sea_orm::*;
79///
80/// let query = construct_obj_dependency_query(1);
81///
82/// assert_eq!(
83///     query.to_string(MysqlQueryBuilder),
84///     r#"WITH RECURSIVE `used_by_object_ids` (`used_by`) AS (SELECT `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `oid` FROM `object` WHERE `object`.`database_id` = 1 OR `object`.`schema_id` = 1) UNION ALL (SELECT `object_dependency`.`used_by` FROM `object_dependency` INNER JOIN `used_by_object_ids` ON `used_by_object_ids`.`used_by` = `oid`)) SELECT DISTINCT `oid`, `obj_type`, `schema_id`, `database_id` FROM `used_by_object_ids` INNER JOIN `object` ON `used_by_object_ids`.`used_by` = `oid` ORDER BY `oid` DESC"#
85/// );
86/// assert_eq!(
87///     query.to_string(PostgresQueryBuilder),
88///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1) UNION ALL (SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid")) SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
89/// );
90/// assert_eq!(
91///     query.to_string(SqliteQueryBuilder),
92///     r#"WITH RECURSIVE "used_by_object_ids" ("used_by") AS (SELECT "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "oid" FROM "object" WHERE "object"."database_id" = 1 OR "object"."schema_id" = 1 UNION ALL SELECT "object_dependency"."used_by" FROM "object_dependency" INNER JOIN "used_by_object_ids" ON "used_by_object_ids"."used_by" = "oid") SELECT DISTINCT "oid", "obj_type", "schema_id", "database_id" FROM "used_by_object_ids" INNER JOIN "object" ON "used_by_object_ids"."used_by" = "oid" ORDER BY "oid" DESC"#
93/// );
94/// ```
95pub fn construct_obj_dependency_query(obj_id: ObjectId) -> WithQuery {
96    let cte_alias = Alias::new("used_by_object_ids");
97    let cte_return_alias = Alias::new("used_by");
98
99    let mut base_query = SelectStatement::new()
100        .column(object_dependency::Column::UsedBy)
101        .from(ObjectDependency)
102        .and_where(object_dependency::Column::Oid.eq(obj_id))
103        .to_owned();
104
105    let belonged_obj_query = SelectStatement::new()
106        .column(object::Column::Oid)
107        .from(Object)
108        .and_where(
109            object::Column::DatabaseId
110                .eq(obj_id)
111                .or(object::Column::SchemaId.eq(obj_id)),
112        )
113        .to_owned();
114
115    let cte_referencing = Query::select()
116        .column((ObjectDependency, object_dependency::Column::UsedBy))
117        .from(ObjectDependency)
118        .inner_join(
119            cte_alias.clone(),
120            Expr::col((cte_alias.clone(), cte_return_alias.clone()))
121                .equals(object_dependency::Column::Oid),
122        )
123        .to_owned();
124
125    let common_table_expr = CommonTableExpression::new()
126        .query(
127            base_query
128                .union(UnionType::All, belonged_obj_query)
129                .union(UnionType::All, cte_referencing)
130                .to_owned(),
131        )
132        .column(cte_return_alias.clone())
133        .table_name(cte_alias.clone())
134        .to_owned();
135
136    SelectStatement::new()
137        .distinct()
138        .columns([
139            object::Column::Oid,
140            object::Column::ObjType,
141            object::Column::SchemaId,
142            object::Column::DatabaseId,
143        ])
144        .from(cte_alias.clone())
145        .inner_join(
146            Object,
147            Expr::col((cte_alias, cte_return_alias.clone())).equals(object::Column::Oid),
148        )
149        .order_by(object::Column::Oid, Order::Desc)
150        .to_owned()
151        .with(
152            WithClause::new()
153                .recursive(true)
154                .cte(common_table_expr)
155                .to_owned(),
156        )
157        .to_owned()
158}
159
160/// This function will construct a query using recursive cte to find if dependent objects are already relying on the target table.
161///
162/// # Examples
163///
164/// ```
165/// use risingwave_meta::controller::utils::construct_sink_cycle_check_query;
166/// use sea_orm::sea_query::*;
167/// use sea_orm::*;
168///
169/// let query = construct_sink_cycle_check_query(1, vec![2, 3]);
170///
171/// assert_eq!(
172///     query.to_string(MysqlQueryBuilder),
173///     r#"WITH RECURSIVE `used_by_object_ids_with_sink` (`oid`, `used_by`) AS (SELECT `oid`, `used_by` FROM `object_dependency` WHERE `object_dependency`.`oid` = 1 UNION ALL (SELECT `obj_dependency_with_sink`.`oid`, `obj_dependency_with_sink`.`used_by` FROM (SELECT `oid`, `used_by` FROM `object_dependency` UNION ALL (SELECT `sink_id`, `target_table` FROM `sink` WHERE `sink`.`target_table` IS NOT NULL)) AS `obj_dependency_with_sink` INNER JOIN `used_by_object_ids_with_sink` ON `used_by_object_ids_with_sink`.`used_by` = `obj_dependency_with_sink`.`oid` WHERE `used_by_object_ids_with_sink`.`used_by` <> `used_by_object_ids_with_sink`.`oid`)) SELECT COUNT(`used_by_object_ids_with_sink`.`used_by`) FROM `used_by_object_ids_with_sink` WHERE `used_by_object_ids_with_sink`.`used_by` IN (2, 3)"#
174/// );
175/// assert_eq!(
176///     query.to_string(PostgresQueryBuilder),
177///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL (SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL (SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL)) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid")) SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
178/// );
179/// assert_eq!(
180///     query.to_string(SqliteQueryBuilder),
181///     r#"WITH RECURSIVE "used_by_object_ids_with_sink" ("oid", "used_by") AS (SELECT "oid", "used_by" FROM "object_dependency" WHERE "object_dependency"."oid" = 1 UNION ALL SELECT "obj_dependency_with_sink"."oid", "obj_dependency_with_sink"."used_by" FROM (SELECT "oid", "used_by" FROM "object_dependency" UNION ALL SELECT "sink_id", "target_table" FROM "sink" WHERE "sink"."target_table" IS NOT NULL) AS "obj_dependency_with_sink" INNER JOIN "used_by_object_ids_with_sink" ON "used_by_object_ids_with_sink"."used_by" = "obj_dependency_with_sink"."oid" WHERE "used_by_object_ids_with_sink"."used_by" <> "used_by_object_ids_with_sink"."oid") SELECT COUNT("used_by_object_ids_with_sink"."used_by") FROM "used_by_object_ids_with_sink" WHERE "used_by_object_ids_with_sink"."used_by" IN (2, 3)"#
182/// );
183/// ```
184pub fn construct_sink_cycle_check_query(
185    target_table: ObjectId,
186    dependent_objects: Vec<ObjectId>,
187) -> WithQuery {
188    let cte_alias = Alias::new("used_by_object_ids_with_sink");
189    let depend_alias = Alias::new("obj_dependency_with_sink");
190
191    let mut base_query = SelectStatement::new()
192        .columns([
193            object_dependency::Column::Oid,
194            object_dependency::Column::UsedBy,
195        ])
196        .from(ObjectDependency)
197        .and_where(object_dependency::Column::Oid.eq(target_table))
198        .to_owned();
199
200    let query_sink_deps = SelectStatement::new()
201        .columns([sink::Column::SinkId, sink::Column::TargetTable])
202        .from(Sink)
203        .and_where(sink::Column::TargetTable.is_not_null())
204        .to_owned();
205
206    let cte_referencing = Query::select()
207        .column((depend_alias.clone(), object_dependency::Column::Oid))
208        .column((depend_alias.clone(), object_dependency::Column::UsedBy))
209        .from_subquery(
210            SelectStatement::new()
211                .columns([
212                    object_dependency::Column::Oid,
213                    object_dependency::Column::UsedBy,
214                ])
215                .from(ObjectDependency)
216                .union(UnionType::All, query_sink_deps)
217                .to_owned(),
218            depend_alias.clone(),
219        )
220        .inner_join(
221            cte_alias.clone(),
222            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).eq(Expr::col((
223                depend_alias.clone(),
224                object_dependency::Column::Oid,
225            ))),
226        )
227        .and_where(
228            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).ne(Expr::col((
229                cte_alias.clone(),
230                object_dependency::Column::Oid,
231            ))),
232        )
233        .to_owned();
234
235    let common_table_expr = CommonTableExpression::new()
236        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
237        .columns([
238            object_dependency::Column::Oid,
239            object_dependency::Column::UsedBy,
240        ])
241        .table_name(cte_alias.clone())
242        .to_owned();
243
244    SelectStatement::new()
245        .expr(Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy)).count())
246        .from(cte_alias.clone())
247        .and_where(
248            Expr::col((cte_alias.clone(), object_dependency::Column::UsedBy))
249                .is_in(dependent_objects),
250        )
251        .to_owned()
252        .with(
253            WithClause::new()
254                .recursive(true)
255                .cte(common_table_expr)
256                .to_owned(),
257        )
258        .to_owned()
259}
260
/// Partial projection of the `Object` entity: an object's id and type, plus the ids of the
/// schema and database it belongs to (both nullable).
///
/// This is the row type returned by [`get_referring_objects_cascade`].
// Field order also fixes the order of the selected columns generated by `DerivePartialModel`.
#[derive(Clone, DerivePartialModel, FromQueryResult, Debug)]
#[sea_orm(entity = "Object")]
pub struct PartialObject {
    pub oid: ObjectId,
    pub obj_type: ObjectType,
    // NOTE(review): presumably `None` for objects not placed in a schema (e.g. databases) — confirm.
    pub schema_id: Option<SchemaId>,
    pub database_id: Option<DatabaseId>,
}
269
/// Partial projection of the `Fragment` entity carrying only a fragment's id, its `job_id`,
/// and the ids of its state tables.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Fragment")]
pub struct PartialFragmentStateTables {
    pub fragment_id: FragmentId,
    pub job_id: ObjectId,
    pub state_table_ids: I32Array,
}
277
/// Partial projection of the `Actor` entity: an actor's id, the fragment it belongs to,
/// the worker id recorded for it, and its status.
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "Actor")]
pub struct PartialActorLocation {
    pub actor_id: ActorId,
    pub fragment_id: FragmentId,
    pub worker_id: WorkerId,
    pub status: ActorStatus,
}
286
/// Row shape describing a fragment, decoded with `FromQueryResult`.
///
/// Unlike a `DerivePartialModel` struct, this is not bound to a single entity. The query that
/// produces it must yield columns named after these fields.
#[derive(FromQueryResult)]
pub struct FragmentDesc {
    pub fragment_id: FragmentId,
    pub job_id: ObjectId,
    pub fragment_type_mask: i32,
    pub distribution_type: DistributionType,
    pub state_table_ids: I32Array,
    // NOTE(review): `i64` suggests a computed aggregate (e.g. an actor count) — confirm against the query.
    pub parallelism: i64,
    pub vnode_count: i32,
    pub stream_node: StreamNode,
}
298
299/// List all objects that are using the given one in a cascade way. It runs a recursive CTE to find all the dependencies.
300pub async fn get_referring_objects_cascade<C>(
301    obj_id: ObjectId,
302    db: &C,
303) -> MetaResult<Vec<PartialObject>>
304where
305    C: ConnectionTrait,
306{
307    let query = construct_obj_dependency_query(obj_id);
308    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
309    let objects = PartialObject::find_by_statement(Statement::from_sql_and_values(
310        db.get_database_backend(),
311        sql,
312        values,
313    ))
314    .all(db)
315    .await?;
316    Ok(objects)
317}
318
319/// Check if create a sink with given dependent objects into the target table will cause a cycle, return true if it will.
320pub async fn check_sink_into_table_cycle<C>(
321    target_table: ObjectId,
322    dependent_objs: Vec<ObjectId>,
323    db: &C,
324) -> MetaResult<bool>
325where
326    C: ConnectionTrait,
327{
328    if dependent_objs.is_empty() {
329        return Ok(false);
330    }
331
332    // special check for self referencing
333    if dependent_objs.contains(&target_table) {
334        return Ok(true);
335    }
336
337    let query = construct_sink_cycle_check_query(target_table, dependent_objs);
338    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
339
340    let res = db
341        .query_one(Statement::from_sql_and_values(
342            db.get_database_backend(),
343            sql,
344            values,
345        ))
346        .await?
347        .unwrap();
348
349    let cnt: i64 = res.try_get_by(0)?;
350
351    Ok(cnt != 0)
352}
353
354/// `ensure_object_id` ensures the existence of target object in the cluster.
355pub async fn ensure_object_id<C>(
356    object_type: ObjectType,
357    obj_id: ObjectId,
358    db: &C,
359) -> MetaResult<()>
360where
361    C: ConnectionTrait,
362{
363    let count = Object::find_by_id(obj_id).count(db).await?;
364    if count == 0 {
365        return Err(MetaError::catalog_id_not_found(
366            object_type.as_str(),
367            obj_id,
368        ));
369    }
370    Ok(())
371}
372
373/// `ensure_user_id` ensures the existence of target user in the cluster.
374pub async fn ensure_user_id<C>(user_id: UserId, db: &C) -> MetaResult<()>
375where
376    C: ConnectionTrait,
377{
378    let count = User::find_by_id(user_id).count(db).await?;
379    if count == 0 {
380        return Err(anyhow!("user {} was concurrently dropped", user_id).into());
381    }
382    Ok(())
383}
384
385/// `check_database_name_duplicate` checks whether the database name is already used in the cluster.
386pub async fn check_database_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
387where
388    C: ConnectionTrait,
389{
390    let count = Database::find()
391        .filter(database::Column::Name.eq(name))
392        .count(db)
393        .await?;
394    if count > 0 {
395        assert_eq!(count, 1);
396        return Err(MetaError::catalog_duplicated("database", name));
397    }
398    Ok(())
399}
400
401/// `check_function_signature_duplicate` checks whether the function name and its signature is already used in the target namespace.
402pub async fn check_function_signature_duplicate<C>(
403    pb_function: &PbFunction,
404    db: &C,
405) -> MetaResult<()>
406where
407    C: ConnectionTrait,
408{
409    let count = Function::find()
410        .inner_join(Object)
411        .filter(
412            object::Column::DatabaseId
413                .eq(pb_function.database_id as DatabaseId)
414                .and(object::Column::SchemaId.eq(pb_function.schema_id as SchemaId))
415                .and(function::Column::Name.eq(&pb_function.name))
416                .and(
417                    function::Column::ArgTypes
418                        .eq(DataTypeArray::from(pb_function.arg_types.clone())),
419                ),
420        )
421        .count(db)
422        .await?;
423    if count > 0 {
424        assert_eq!(count, 1);
425        return Err(MetaError::catalog_duplicated("function", &pb_function.name));
426    }
427    Ok(())
428}
429
430/// `check_connection_name_duplicate` checks whether the connection name is already used in the target namespace.
431pub async fn check_connection_name_duplicate<C>(
432    pb_connection: &PbConnection,
433    db: &C,
434) -> MetaResult<()>
435where
436    C: ConnectionTrait,
437{
438    let count = Connection::find()
439        .inner_join(Object)
440        .filter(
441            object::Column::DatabaseId
442                .eq(pb_connection.database_id as DatabaseId)
443                .and(object::Column::SchemaId.eq(pb_connection.schema_id as SchemaId))
444                .and(connection::Column::Name.eq(&pb_connection.name)),
445        )
446        .count(db)
447        .await?;
448    if count > 0 {
449        assert_eq!(count, 1);
450        return Err(MetaError::catalog_duplicated(
451            "connection",
452            &pb_connection.name,
453        ));
454    }
455    Ok(())
456}
457
458pub async fn check_secret_name_duplicate<C>(pb_secret: &PbSecret, db: &C) -> MetaResult<()>
459where
460    C: ConnectionTrait,
461{
462    let count = Secret::find()
463        .inner_join(Object)
464        .filter(
465            object::Column::DatabaseId
466                .eq(pb_secret.database_id as DatabaseId)
467                .and(object::Column::SchemaId.eq(pb_secret.schema_id as SchemaId))
468                .and(secret::Column::Name.eq(&pb_secret.name)),
469        )
470        .count(db)
471        .await?;
472    if count > 0 {
473        assert_eq!(count, 1);
474        return Err(MetaError::catalog_duplicated("secret", &pb_secret.name));
475    }
476    Ok(())
477}
478
479pub async fn check_subscription_name_duplicate<C>(
480    pb_subscription: &PbSubscription,
481    db: &C,
482) -> MetaResult<()>
483where
484    C: ConnectionTrait,
485{
486    let count = Subscription::find()
487        .inner_join(Object)
488        .filter(
489            object::Column::DatabaseId
490                .eq(pb_subscription.database_id as DatabaseId)
491                .and(object::Column::SchemaId.eq(pb_subscription.schema_id as SchemaId))
492                .and(subscription::Column::Name.eq(&pb_subscription.name)),
493        )
494        .count(db)
495        .await?;
496    if count > 0 {
497        assert_eq!(count, 1);
498        return Err(MetaError::catalog_duplicated(
499            "subscription",
500            &pb_subscription.name,
501        ));
502    }
503    Ok(())
504}
505
506/// `check_user_name_duplicate` checks whether the user is already existed in the cluster.
507pub async fn check_user_name_duplicate<C>(name: &str, db: &C) -> MetaResult<()>
508where
509    C: ConnectionTrait,
510{
511    let count = User::find()
512        .filter(user::Column::Name.eq(name))
513        .count(db)
514        .await?;
515    if count > 0 {
516        assert_eq!(count, 1);
517        return Err(MetaError::catalog_duplicated("user", name));
518    }
519    Ok(())
520}
521
522/// `check_relation_name_duplicate` checks whether the relation name is already used in the target namespace.
523pub async fn check_relation_name_duplicate<C>(
524    name: &str,
525    database_id: DatabaseId,
526    schema_id: SchemaId,
527    db: &C,
528) -> MetaResult<()>
529where
530    C: ConnectionTrait,
531{
532    macro_rules! check_duplicated {
533        ($obj_type:expr, $entity:ident, $table:ident) => {
534            let object_id = Object::find()
535                .select_only()
536                .column(object::Column::Oid)
537                .inner_join($entity)
538                .filter(
539                    object::Column::DatabaseId
540                        .eq(Some(database_id))
541                        .and(object::Column::SchemaId.eq(Some(schema_id)))
542                        .and($table::Column::Name.eq(name)),
543                )
544                .into_tuple::<ObjectId>()
545                .one(db)
546                .await?;
547            if let Some(oid) = object_id {
548                let check_creation = if $obj_type == ObjectType::View {
549                    false
550                } else if $obj_type == ObjectType::Source {
551                    let source_info = Source::find_by_id(oid)
552                        .select_only()
553                        .column(source::Column::SourceInfo)
554                        .into_tuple::<Option<StreamSourceInfo>>()
555                        .one(db)
556                        .await?
557                        .unwrap();
558                    source_info.map_or(false, |info| info.to_protobuf().is_shared())
559                } else {
560                    true
561                };
562                return if check_creation
563                    && !matches!(
564                        StreamingJob::find_by_id(oid)
565                            .select_only()
566                            .column(streaming_job::Column::JobStatus)
567                            .into_tuple::<JobStatus>()
568                            .one(db)
569                            .await?,
570                        Some(JobStatus::Created)
571                    ) {
572                    Err(MetaError::catalog_under_creation(
573                        $obj_type.as_str(),
574                        name,
575                        oid,
576                    ))
577                } else {
578                    Err(MetaError::catalog_duplicated($obj_type.as_str(), name))
579                };
580            }
581        };
582    }
583    check_duplicated!(ObjectType::Table, Table, table);
584    check_duplicated!(ObjectType::Source, Source, source);
585    check_duplicated!(ObjectType::Sink, Sink, sink);
586    check_duplicated!(ObjectType::Index, Index, index);
587    check_duplicated!(ObjectType::View, View, view);
588
589    Ok(())
590}
591
592/// `check_schema_name_duplicate` checks whether the schema name is already used in the target database.
593pub async fn check_schema_name_duplicate<C>(
594    name: &str,
595    database_id: DatabaseId,
596    db: &C,
597) -> MetaResult<()>
598where
599    C: ConnectionTrait,
600{
601    let count = Object::find()
602        .inner_join(Schema)
603        .filter(
604            object::Column::ObjType
605                .eq(ObjectType::Schema)
606                .and(object::Column::DatabaseId.eq(Some(database_id)))
607                .and(schema::Column::Name.eq(name)),
608        )
609        .count(db)
610        .await?;
611    if count != 0 {
612        return Err(MetaError::catalog_duplicated("schema", name));
613    }
614
615    Ok(())
616}
617
618/// `check_object_refer_for_drop` checks whether the object is used by other objects except indexes.
619/// It returns an error that contains the details of the referring objects if it is used by others.
620pub async fn check_object_refer_for_drop<C>(
621    object_type: ObjectType,
622    object_id: ObjectId,
623    db: &C,
624) -> MetaResult<()>
625where
626    C: ConnectionTrait,
627{
628    // Ignore indexes.
629    let count = if object_type == ObjectType::Table {
630        ObjectDependency::find()
631            .join(
632                JoinType::InnerJoin,
633                object_dependency::Relation::Object1.def(),
634            )
635            .filter(
636                object_dependency::Column::Oid
637                    .eq(object_id)
638                    .and(object::Column::ObjType.ne(ObjectType::Index)),
639            )
640            .count(db)
641            .await?
642    } else {
643        ObjectDependency::find()
644            .filter(object_dependency::Column::Oid.eq(object_id))
645            .count(db)
646            .await?
647    };
648    if count != 0 {
649        // find the name of all objects that are using the given one.
650        let referring_objects = get_referring_objects(object_id, db).await?;
651        let referring_objs_map = referring_objects
652            .into_iter()
653            .filter(|o| o.obj_type != ObjectType::Index)
654            .into_group_map_by(|o| o.obj_type);
655        let mut details = vec![];
656        for (obj_type, objs) in referring_objs_map {
657            match obj_type {
658                ObjectType::Table => {
659                    let tables: Vec<(String, String)> = Object::find()
660                        .join(JoinType::InnerJoin, object::Relation::Table.def())
661                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
662                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
663                        .select_only()
664                        .column(schema::Column::Name)
665                        .column(table::Column::Name)
666                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
667                        .into_tuple()
668                        .all(db)
669                        .await?;
670                    details.extend(tables.into_iter().map(|(schema_name, table_name)| {
671                        format!(
672                            "materialized view {}.{} depends on it",
673                            schema_name, table_name
674                        )
675                    }));
676                }
677                ObjectType::Sink => {
678                    let sinks: Vec<(String, String)> = Object::find()
679                        .join(JoinType::InnerJoin, object::Relation::Sink.def())
680                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
681                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
682                        .select_only()
683                        .column(schema::Column::Name)
684                        .column(sink::Column::Name)
685                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
686                        .into_tuple()
687                        .all(db)
688                        .await?;
689                    details.extend(sinks.into_iter().map(|(schema_name, sink_name)| {
690                        format!("sink {}.{} depends on it", schema_name, sink_name)
691                    }));
692                }
693                ObjectType::View => {
694                    let views: Vec<(String, String)> = Object::find()
695                        .join(JoinType::InnerJoin, object::Relation::View.def())
696                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
697                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
698                        .select_only()
699                        .column(schema::Column::Name)
700                        .column(view::Column::Name)
701                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
702                        .into_tuple()
703                        .all(db)
704                        .await?;
705                    details.extend(views.into_iter().map(|(schema_name, view_name)| {
706                        format!("view {}.{} depends on it", schema_name, view_name)
707                    }));
708                }
709                ObjectType::Subscription => {
710                    let subscriptions: Vec<(String, String)> = Object::find()
711                        .join(JoinType::InnerJoin, object::Relation::Subscription.def())
712                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
713                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
714                        .select_only()
715                        .column(schema::Column::Name)
716                        .column(subscription::Column::Name)
717                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
718                        .into_tuple()
719                        .all(db)
720                        .await?;
721                    details.extend(subscriptions.into_iter().map(
722                        |(schema_name, subscription_name)| {
723                            format!(
724                                "subscription {}.{} depends on it",
725                                schema_name, subscription_name
726                            )
727                        },
728                    ));
729                }
730                ObjectType::Source => {
731                    let sources: Vec<(String, String)> = Object::find()
732                        .join(JoinType::InnerJoin, object::Relation::Source.def())
733                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
734                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
735                        .select_only()
736                        .column(schema::Column::Name)
737                        .column(source::Column::Name)
738                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
739                        .into_tuple()
740                        .all(db)
741                        .await?;
742                    details.extend(sources.into_iter().map(|(schema_name, view_name)| {
743                        format!("source {}.{} depends on it", schema_name, view_name)
744                    }));
745                }
746                ObjectType::Connection => {
747                    let connections: Vec<(String, String)> = Object::find()
748                        .join(JoinType::InnerJoin, object::Relation::Connection.def())
749                        .join(JoinType::InnerJoin, object::Relation::Database2.def())
750                        .join(JoinType::InnerJoin, object::Relation::Schema2.def())
751                        .select_only()
752                        .column(schema::Column::Name)
753                        .column(connection::Column::Name)
754                        .filter(object::Column::Oid.is_in(objs.iter().map(|o| o.oid)))
755                        .into_tuple()
756                        .all(db)
757                        .await?;
758                    details.extend(connections.into_iter().map(|(schema_name, view_name)| {
759                        format!("connection {}.{} depends on it", schema_name, view_name)
760                    }));
761                }
762                // only the table, source, sink, subscription, view, connection and index will depend on other objects.
763                _ => bail!("unexpected referring object type: {}", obj_type.as_str()),
764            }
765        }
766
767        return Err(MetaError::permission_denied(format!(
768            "{} used by {} other objects. \nDETAIL: {}\n\
769            {}",
770            object_type.as_str(),
771            count,
772            details.join("\n"),
773            match object_type {
774                ObjectType::Function | ObjectType::Connection | ObjectType::Secret =>
775                    "HINT: DROP the dependent objects first.",
776                ObjectType::Database | ObjectType::Schema => unreachable!(),
777                _ => "HINT:  Use DROP ... CASCADE to drop the dependent objects too.",
778            }
779        )));
780    }
781    Ok(())
782}
783
784/// List all objects that are using the given one.
785pub async fn get_referring_objects<C>(object_id: ObjectId, db: &C) -> MetaResult<Vec<PartialObject>>
786where
787    C: ConnectionTrait,
788{
789    let objs = ObjectDependency::find()
790        .filter(object_dependency::Column::Oid.eq(object_id))
791        .join(
792            JoinType::InnerJoin,
793            object_dependency::Relation::Object1.def(),
794        )
795        .into_partial_model()
796        .all(db)
797        .await?;
798
799    Ok(objs)
800}
801
802/// `ensure_schema_empty` ensures that the schema is empty, used by `DROP SCHEMA`.
803pub async fn ensure_schema_empty<C>(schema_id: SchemaId, db: &C) -> MetaResult<()>
804where
805    C: ConnectionTrait,
806{
807    let count = Object::find()
808        .filter(object::Column::SchemaId.eq(Some(schema_id)))
809        .count(db)
810        .await?;
811    if count != 0 {
812        return Err(MetaError::permission_denied("schema is not empty"));
813    }
814
815    Ok(())
816}
817
818/// `list_user_info_by_ids` lists all users' info by their ids.
819pub async fn list_user_info_by_ids<C>(
820    user_ids: impl IntoIterator<Item = UserId>,
821    db: &C,
822) -> MetaResult<Vec<PbUserInfo>>
823where
824    C: ConnectionTrait,
825{
826    let mut user_infos = vec![];
827    for user_id in user_ids {
828        let user = User::find_by_id(user_id)
829            .one(db)
830            .await?
831            .ok_or_else(|| MetaError::catalog_id_not_found("user", user_id))?;
832        let mut user_info: PbUserInfo = user.into();
833        user_info.grant_privileges = get_user_privilege(user_id, db).await?;
834        user_infos.push(user_info);
835    }
836    Ok(user_infos)
837}
838
839/// `get_object_owner` returns the owner of the given object.
840pub async fn get_object_owner<C>(object_id: ObjectId, db: &C) -> MetaResult<UserId>
841where
842    C: ConnectionTrait,
843{
844    let obj_owner: UserId = Object::find_by_id(object_id)
845        .select_only()
846        .column(object::Column::OwnerId)
847        .into_tuple()
848        .one(db)
849        .await?
850        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
851    Ok(obj_owner)
852}
853
/// `construct_privilege_dependency_query` constructs a query to find all privileges that are dependent on the given one.
///
/// The query is a recursive CTE named `granted_privilege_ids` with columns `(id, user_id)`:
/// - The anchor member seeds it with the `user_privilege` rows whose `id` is in `ids`.
/// - Each recursive step adds the `user_privilege` rows whose `dependent_id` equals an `id`
///   already collected.
///
/// The outer `SELECT` returns every collected `(id, user_id)` pair, so the seed privileges
/// themselves are part of the result.
///
/// # Examples
///
/// ```
/// use risingwave_meta::controller::utils::construct_privilege_dependency_query;
/// use sea_orm::sea_query::*;
/// use sea_orm::*;
///
/// let query = construct_privilege_dependency_query(vec![1, 2, 3]);
///
/// assert_eq!(
///    query.to_string(MysqlQueryBuilder),
///   r#"WITH RECURSIVE `granted_privilege_ids` (`id`, `user_id`) AS (SELECT `id`, `user_id` FROM `user_privilege` WHERE `user_privilege`.`id` IN (1, 2, 3) UNION ALL (SELECT `user_privilege`.`id`, `user_privilege`.`user_id` FROM `user_privilege` INNER JOIN `granted_privilege_ids` ON `granted_privilege_ids`.`id` = `dependent_id`)) SELECT `id`, `user_id` FROM `granted_privilege_ids`"#
/// );
/// assert_eq!(
///   query.to_string(PostgresQueryBuilder),
///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL (SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id")) SELECT "id", "user_id" FROM "granted_privilege_ids""#
/// );
/// assert_eq!(
///  query.to_string(SqliteQueryBuilder),
///  r#"WITH RECURSIVE "granted_privilege_ids" ("id", "user_id") AS (SELECT "id", "user_id" FROM "user_privilege" WHERE "user_privilege"."id" IN (1, 2, 3) UNION ALL SELECT "user_privilege"."id", "user_privilege"."user_id" FROM "user_privilege" INNER JOIN "granted_privilege_ids" ON "granted_privilege_ids"."id" = "dependent_id") SELECT "id", "user_id" FROM "granted_privilege_ids""#
/// );
/// ```
pub fn construct_privilege_dependency_query(ids: Vec<PrivilegeId>) -> WithQuery {
    // Name of the recursive CTE and of its two output columns.
    let cte_alias = Alias::new("granted_privilege_ids");
    let cte_return_privilege_alias = Alias::new("id");
    let cte_return_user_alias = Alias::new("user_id");

    // Anchor member: the privileges passed in by the caller.
    let mut base_query = SelectStatement::new()
        .columns([user_privilege::Column::Id, user_privilege::Column::UserId])
        .from(UserPrivilege)
        .and_where(user_privilege::Column::Id.is_in(ids))
        .to_owned();

    // Recursive member: privileges whose `dependent_id` points at a privilege already in the CTE.
    let cte_referencing = Query::select()
        .columns([
            (UserPrivilege, user_privilege::Column::Id),
            (UserPrivilege, user_privilege::Column::UserId),
        ])
        .from(UserPrivilege)
        .inner_join(
            cte_alias.clone(),
            Expr::col((cte_alias.clone(), cte_return_privilege_alias.clone()))
                .equals(user_privilege::Column::DependentId),
        )
        .to_owned();

    // `anchor UNION ALL recursive`, exposed as `granted_privilege_ids(id, user_id)`.
    let common_table_expr = CommonTableExpression::new()
        .query(base_query.union(UnionType::All, cte_referencing).to_owned())
        .columns([
            cte_return_privilege_alias.clone(),
            cte_return_user_alias.clone(),
        ])
        .table_name(cte_alias.clone())
        .to_owned();

    // Outer query reading the whole CTE back, wrapped in a `WITH RECURSIVE` clause.
    SelectStatement::new()
        .columns([cte_return_privilege_alias, cte_return_user_alias])
        .from(cte_alias.clone())
        .to_owned()
        .with(
            WithClause::new()
                .recursive(true)
                .cte(common_table_expr)
                .to_owned(),
        )
        .to_owned()
}
923
924pub async fn get_internal_tables_by_id<C>(job_id: ObjectId, db: &C) -> MetaResult<Vec<TableId>>
925where
926    C: ConnectionTrait,
927{
928    let table_ids: Vec<TableId> = Table::find()
929        .select_only()
930        .column(table::Column::TableId)
931        .filter(
932            table::Column::TableType
933                .eq(TableType::Internal)
934                .and(table::Column::BelongsToJobId.eq(job_id)),
935        )
936        .into_tuple()
937        .all(db)
938        .await?;
939    Ok(table_ids)
940}
941
942pub async fn get_index_state_tables_by_table_id<C>(
943    table_id: TableId,
944    db: &C,
945) -> MetaResult<Vec<TableId>>
946where
947    C: ConnectionTrait,
948{
949    let mut index_table_ids: Vec<TableId> = Index::find()
950        .select_only()
951        .column(index::Column::IndexTableId)
952        .filter(index::Column::PrimaryTableId.eq(table_id))
953        .into_tuple()
954        .all(db)
955        .await?;
956
957    if !index_table_ids.is_empty() {
958        let internal_table_ids: Vec<TableId> = Table::find()
959            .select_only()
960            .column(table::Column::TableId)
961            .filter(
962                table::Column::TableType
963                    .eq(TableType::Internal)
964                    .and(table::Column::BelongsToJobId.is_in(index_table_ids.clone())),
965            )
966            .into_tuple()
967            .all(db)
968            .await?;
969
970        index_table_ids.extend(internal_table_ids.into_iter());
971    }
972
973    Ok(index_table_ids)
974}
975
/// Minimal projection of a `UserPrivilege` row: the privilege id and the user holding it.
///
/// Rows of this shape are produced by the query from
/// [`construct_privilege_dependency_query`] in [`get_referring_privileges_cascade`].
#[derive(Clone, DerivePartialModel, FromQueryResult)]
#[sea_orm(entity = "UserPrivilege")]
pub struct PartialUserPrivilege {
    /// The privilege's `id` column.
    pub id: PrivilegeId,
    /// The privilege's `user_id` column.
    pub user_id: UserId,
}
982
983pub async fn get_referring_privileges_cascade<C>(
984    ids: Vec<PrivilegeId>,
985    db: &C,
986) -> MetaResult<Vec<PartialUserPrivilege>>
987where
988    C: ConnectionTrait,
989{
990    let query = construct_privilege_dependency_query(ids);
991    let (sql, values) = query.build_any(&*db.get_database_backend().get_query_builder());
992    let privileges = PartialUserPrivilege::find_by_statement(Statement::from_sql_and_values(
993        db.get_database_backend(),
994        sql,
995        values,
996    ))
997    .all(db)
998    .await?;
999
1000    Ok(privileges)
1001}
1002
1003/// `ensure_privileges_not_referred` ensures that the privileges are not granted to any other users.
1004pub async fn ensure_privileges_not_referred<C>(ids: Vec<PrivilegeId>, db: &C) -> MetaResult<()>
1005where
1006    C: ConnectionTrait,
1007{
1008    let count = UserPrivilege::find()
1009        .filter(user_privilege::Column::DependentId.is_in(ids))
1010        .count(db)
1011        .await?;
1012    if count != 0 {
1013        return Err(MetaError::permission_denied(format!(
1014            "privileges granted to {} other ones.",
1015            count
1016        )));
1017    }
1018    Ok(())
1019}
1020
1021/// `get_user_privilege` returns the privileges of the given user.
1022pub async fn get_user_privilege<C>(user_id: UserId, db: &C) -> MetaResult<Vec<PbGrantPrivilege>>
1023where
1024    C: ConnectionTrait,
1025{
1026    let user_privileges = UserPrivilege::find()
1027        .find_also_related(Object)
1028        .filter(user_privilege::Column::UserId.eq(user_id))
1029        .all(db)
1030        .await?;
1031    Ok(user_privileges
1032        .into_iter()
1033        .map(|(privilege, object)| {
1034            let object = object.unwrap();
1035            let oid = object.oid as _;
1036            let obj = match object.obj_type {
1037                ObjectType::Database => PbGrantObject::DatabaseId(oid),
1038                ObjectType::Schema => PbGrantObject::SchemaId(oid),
1039                ObjectType::Table | ObjectType::Index => PbGrantObject::TableId(oid),
1040                ObjectType::Source => PbGrantObject::SourceId(oid),
1041                ObjectType::Sink => PbGrantObject::SinkId(oid),
1042                ObjectType::View => PbGrantObject::ViewId(oid),
1043                ObjectType::Function => PbGrantObject::FunctionId(oid),
1044                ObjectType::Connection => PbGrantObject::ConnectionId(oid),
1045                ObjectType::Subscription => PbGrantObject::SubscriptionId(oid),
1046                ObjectType::Secret => PbGrantObject::SecretId(oid),
1047            };
1048            PbGrantPrivilege {
1049                action_with_opts: vec![PbActionWithGrantOption {
1050                    action: PbAction::from(privilege.action) as _,
1051                    with_grant_option: privilege.with_grant_option,
1052                    granted_by: privilege.granted_by as _,
1053                }],
1054                object: Some(obj),
1055            }
1056        })
1057        .collect())
1058}
1059
1060pub async fn get_table_columns(
1061    txn: &impl ConnectionTrait,
1062    id: TableId,
1063) -> MetaResult<ColumnCatalogArray> {
1064    let columns = Table::find_by_id(id)
1065        .select_only()
1066        .columns([table::Column::Columns])
1067        .into_tuple::<ColumnCatalogArray>()
1068        .one(txn)
1069        .await?
1070        .ok_or_else(|| MetaError::catalog_id_not_found("table", id))?;
1071    Ok(columns)
1072}
1073
1074/// `grant_default_privileges_automatically` grants default privileges automatically
1075/// for the given new object. It returns the list of user infos whose privileges are updated.
1076pub async fn grant_default_privileges_automatically<C>(
1077    db: &C,
1078    object_id: ObjectId,
1079) -> MetaResult<Vec<PbUserInfo>>
1080where
1081    C: ConnectionTrait,
1082{
1083    let object = Object::find_by_id(object_id)
1084        .one(db)
1085        .await?
1086        .ok_or_else(|| MetaError::catalog_id_not_found("object", object_id))?;
1087    assert_ne!(object.obj_type, ObjectType::Database);
1088
1089    let for_mview_filter = if object.obj_type == ObjectType::Table {
1090        let table_type = Table::find_by_id(object_id)
1091            .select_only()
1092            .column(table::Column::TableType)
1093            .into_tuple::<TableType>()
1094            .one(db)
1095            .await?
1096            .ok_or_else(|| MetaError::catalog_id_not_found("table", object_id))?;
1097        user_default_privilege::Column::ForMaterializedView
1098            .eq(table_type == TableType::MaterializedView)
1099    } else {
1100        user_default_privilege::Column::ForMaterializedView.eq(false)
1101    };
1102    let schema_filter = if let Some(schema_id) = &object.schema_id {
1103        user_default_privilege::Column::SchemaId.eq(*schema_id)
1104    } else {
1105        user_default_privilege::Column::SchemaId.is_null()
1106    };
1107
1108    let default_privileges: Vec<(UserId, UserId, Action, bool)> = UserDefaultPrivilege::find()
1109        .select_only()
1110        .columns([
1111            user_default_privilege::Column::Grantee,
1112            user_default_privilege::Column::GrantedBy,
1113            user_default_privilege::Column::Action,
1114            user_default_privilege::Column::WithGrantOption,
1115        ])
1116        .filter(
1117            user_default_privilege::Column::DatabaseId
1118                .eq(object.database_id.unwrap())
1119                .and(schema_filter)
1120                .and(user_default_privilege::Column::UserId.eq(object.owner_id))
1121                .and(user_default_privilege::Column::ObjectType.eq(object.obj_type))
1122                .and(for_mview_filter),
1123        )
1124        .into_tuple()
1125        .all(db)
1126        .await?;
1127    if default_privileges.is_empty() {
1128        return Ok(vec![]);
1129    }
1130
1131    let updated_user_ids = default_privileges
1132        .iter()
1133        .map(|(grantee, _, _, _)| *grantee)
1134        .collect::<HashSet<_>>();
1135
1136    let internal_table_ids = get_internal_tables_by_id(object_id, db).await?;
1137
1138    for (grantee, granted_by, action, with_grant_option) in default_privileges {
1139        UserPrivilege::insert(user_privilege::ActiveModel {
1140            user_id: Set(grantee),
1141            oid: Set(object_id),
1142            granted_by: Set(granted_by),
1143            action: Set(action),
1144            with_grant_option: Set(with_grant_option),
1145            ..Default::default()
1146        })
1147        .exec(db)
1148        .await?;
1149        if action == Action::Select && !internal_table_ids.is_empty() {
1150            // Grant SELECT privilege for internal tables if the action is SELECT.
1151            for internal_table_id in &internal_table_ids {
1152                UserPrivilege::insert(user_privilege::ActiveModel {
1153                    user_id: Set(grantee),
1154                    oid: Set(*internal_table_id as _),
1155                    granted_by: Set(granted_by),
1156                    action: Set(Action::Select),
1157                    with_grant_option: Set(with_grant_option),
1158                    ..Default::default()
1159                })
1160                .exec(db)
1161                .await?;
1162            }
1163        }
1164    }
1165
1166    let updated_user_infos = list_user_info_by_ids(updated_user_ids, db).await?;
1167    Ok(updated_user_infos)
1168}
1169
1170// todo: remove it after migrated to sql backend.
1171pub fn extract_grant_obj_id(object: &PbGrantObject) -> ObjectId {
1172    match object {
1173        PbGrantObject::DatabaseId(id)
1174        | PbGrantObject::SchemaId(id)
1175        | PbGrantObject::TableId(id)
1176        | PbGrantObject::SourceId(id)
1177        | PbGrantObject::SinkId(id)
1178        | PbGrantObject::ViewId(id)
1179        | PbGrantObject::FunctionId(id)
1180        | PbGrantObject::SubscriptionId(id)
1181        | PbGrantObject::ConnectionId(id)
1182        | PbGrantObject::SecretId(id) => *id as _,
1183    }
1184}
1185
1186pub async fn insert_fragment_relations(
1187    db: &impl ConnectionTrait,
1188    downstream_fragment_relations: &FragmentDownstreamRelation,
1189) -> MetaResult<()> {
1190    for (upstream_fragment_id, downstreams) in downstream_fragment_relations {
1191        for downstream in downstreams {
1192            let relation = fragment_relation::Model {
1193                source_fragment_id: *upstream_fragment_id as _,
1194                target_fragment_id: downstream.downstream_fragment_id as _,
1195                dispatcher_type: downstream.dispatcher_type,
1196                dist_key_indices: downstream
1197                    .dist_key_indices
1198                    .iter()
1199                    .map(|idx| *idx as i32)
1200                    .collect_vec()
1201                    .into(),
1202                output_indices: downstream
1203                    .output_mapping
1204                    .indices
1205                    .iter()
1206                    .map(|idx| *idx as i32)
1207                    .collect_vec()
1208                    .into(),
1209                output_type_mapping: Some(downstream.output_mapping.types.clone().into()),
1210            };
1211            FragmentRelation::insert(relation.into_active_model())
1212                .exec(db)
1213                .await?;
1214        }
1215    }
1216    Ok(())
1217}
1218
/// Builds the dispatchers owned by the actors of the given upstream fragments.
///
/// For every `fragment_relation` row whose `source_fragment_id` is in `fragment_ids`, the
/// `Running` actors of the source and target fragments are loaded and passed to
/// [`compose_dispatchers`]. The returned map is keyed by source fragment id, then by source
/// actor id. Each entry holds one dispatcher per outgoing relation of that actor.
///
/// # Errors
///
/// Returns a database error, or "failed to find fragment" when the actor lookup for a
/// source or target fragment returns no rows.
///
/// # Panics
///
/// Panics if one fragment id yields more than one fragment row, or if
/// [`compose_dispatchers`] rejects the actor layout.
pub async fn get_fragment_actor_dispatchers<C>(
    db: &C,
    fragment_ids: Vec<FragmentId>,
) -> MetaResult<FragmentActorDispatchers>
where
    C: ConnectionTrait,
{
    // For one fragment: its distribution type and a map from running actor id to vnode bitmap (if any).
    type FragmentActorInfo = (
        DistributionType,
        Arc<HashMap<crate::model::ActorId, Option<Bitmap>>>,
    );
    // Memoizes `get_fragment_actors`, so a fragment referenced by several relations is loaded once.
    let mut fragment_actor_cache: HashMap<FragmentId, FragmentActorInfo> = HashMap::new();
    // Loads one fragment together with its `Running` actors.
    let get_fragment_actors = |fragment_id: FragmentId| async move {
        let result: MetaResult<FragmentActorInfo> = try {
            let mut fragment_actors = Fragment::find_by_id(fragment_id)
                .find_with_related(Actor)
                .filter(actor::Column::Status.eq(ActorStatus::Running))
                .all(db)
                .await?;
            // NOTE(review): the status filter applies to the related actors, so an existing
            // fragment with no running actor presumably also ends up here. Confirm.
            if fragment_actors.is_empty() {
                return Err(anyhow!("failed to find fragment: {}", fragment_id).into());
            }
            // The lookup is by primary key, so exactly one fragment row is expected.
            assert_eq!(
                fragment_actors.len(),
                1,
                "find multiple fragment {:?}",
                fragment_actors
            );
            let (fragment, actors) = fragment_actors.pop().unwrap();
            (
                fragment.distribution_type,
                Arc::new(
                    actors
                        .into_iter()
                        .map(|actor| {
                            (
                                actor.actor_id as _,
                                actor
                                    .vnode_bitmap
                                    .map(|bitmap| Bitmap::from(bitmap.to_protobuf())),
                            )
                        })
                        .collect(),
                ),
            )
        };
        result
    };
    // Relations going out of the requested fragments.
    let fragment_relations = FragmentRelation::find()
        .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids))
        .all(db)
        .await?;

    // Maps source fragment id -> source actor id -> dispatchers of that actor.
    let mut actor_dispatchers_map: HashMap<_, HashMap<_, Vec<_>>> = HashMap::new();
    for fragment_relation::Model {
        source_fragment_id,
        target_fragment_id,
        dispatcher_type,
        dist_key_indices,
        output_indices,
        output_type_mapping,
    } in fragment_relations
    {
        // Upstream fragment info, served from the cache (cloning the `Arc` is a refcount bump).
        let (source_fragment_distribution, source_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(source_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(source_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // Downstream fragment info, served from the same cache.
        let (target_fragment_distribution, target_fragment_actors) = {
            let (distribution, actors) = {
                match fragment_actor_cache.entry(target_fragment_id) {
                    Entry::Occupied(entry) => entry.into_mut(),
                    Entry::Vacant(entry) => {
                        entry.insert(get_fragment_actors(target_fragment_id).await?)
                    }
                }
            };
            (*distribution, actors.clone())
        };
        // A missing type mapping falls back to the default (empty) mapping.
        let output_mapping = PbDispatchOutputMapping {
            indices: output_indices.into_u32_array(),
            types: output_type_mapping.unwrap_or_default().to_protobuf(),
        };
        let dispatchers = compose_dispatchers(
            source_fragment_distribution,
            &source_fragment_actors,
            target_fragment_id as _,
            target_fragment_distribution,
            &target_fragment_actors,
            dispatcher_type,
            dist_key_indices.into_u32_array(),
            output_mapping,
        );
        // Append this relation's dispatcher to each upstream actor's dispatcher list.
        let actor_dispatchers_map = actor_dispatchers_map
            .entry(source_fragment_id as _)
            .or_default();
        for (actor_id, dispatchers) in dispatchers {
            actor_dispatchers_map
                .entry(actor_id as _)
                .or_default()
                .push(dispatchers);
        }
    }
    Ok(actor_dispatchers_map)
}
1330
1331pub fn compose_dispatchers(
1332    source_fragment_distribution: DistributionType,
1333    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1334    target_fragment_id: crate::model::FragmentId,
1335    target_fragment_distribution: DistributionType,
1336    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1337    dispatcher_type: DispatcherType,
1338    dist_key_indices: Vec<u32>,
1339    output_mapping: PbDispatchOutputMapping,
1340) -> HashMap<crate::model::ActorId, PbDispatcher> {
1341    match dispatcher_type {
1342        DispatcherType::Hash => {
1343            let dispatcher = PbDispatcher {
1344                r#type: PbDispatcherType::from(dispatcher_type) as _,
1345                dist_key_indices: dist_key_indices.clone(),
1346                output_mapping: output_mapping.into(),
1347                hash_mapping: Some(
1348                    ActorMapping::from_bitmaps(
1349                        &target_fragment_actors
1350                            .iter()
1351                            .map(|(actor_id, bitmap)| {
1352                                (
1353                                    *actor_id as _,
1354                                    bitmap
1355                                        .clone()
1356                                        .expect("downstream hash dispatch must have distribution"),
1357                                )
1358                            })
1359                            .collect(),
1360                    )
1361                    .to_protobuf(),
1362                ),
1363                dispatcher_id: target_fragment_id as _,
1364                downstream_actor_id: target_fragment_actors
1365                    .keys()
1366                    .map(|actor_id| *actor_id as _)
1367                    .collect(),
1368            };
1369            source_fragment_actors
1370                .keys()
1371                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1372                .collect()
1373        }
1374        DispatcherType::Broadcast | DispatcherType::Simple => {
1375            let dispatcher = PbDispatcher {
1376                r#type: PbDispatcherType::from(dispatcher_type) as _,
1377                dist_key_indices: dist_key_indices.clone(),
1378                output_mapping: output_mapping.into(),
1379                hash_mapping: None,
1380                dispatcher_id: target_fragment_id as _,
1381                downstream_actor_id: target_fragment_actors
1382                    .keys()
1383                    .map(|actor_id| *actor_id as _)
1384                    .collect(),
1385            };
1386            source_fragment_actors
1387                .keys()
1388                .map(|source_actor_id| (*source_actor_id, dispatcher.clone()))
1389                .collect()
1390        }
1391        DispatcherType::NoShuffle => resolve_no_shuffle_actor_dispatcher(
1392            source_fragment_distribution,
1393            source_fragment_actors,
1394            target_fragment_distribution,
1395            target_fragment_actors,
1396        )
1397        .into_iter()
1398        .map(|(upstream_actor_id, downstream_actor_id)| {
1399            (
1400                upstream_actor_id,
1401                PbDispatcher {
1402                    r#type: PbDispatcherType::NoShuffle as _,
1403                    dist_key_indices: dist_key_indices.clone(),
1404                    output_mapping: output_mapping.clone().into(),
1405                    hash_mapping: None,
1406                    dispatcher_id: target_fragment_id as _,
1407                    downstream_actor_id: vec![downstream_actor_id as _],
1408                },
1409            )
1410        })
1411        .collect(),
1412    }
1413}
1414
1415/// return (`upstream_actor_id` -> `downstream_actor_id`)
1416pub fn resolve_no_shuffle_actor_dispatcher(
1417    source_fragment_distribution: DistributionType,
1418    source_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1419    target_fragment_distribution: DistributionType,
1420    target_fragment_actors: &HashMap<crate::model::ActorId, Option<Bitmap>>,
1421) -> Vec<(crate::model::ActorId, crate::model::ActorId)> {
1422    assert_eq!(source_fragment_distribution, target_fragment_distribution);
1423    assert_eq!(
1424        source_fragment_actors.len(),
1425        target_fragment_actors.len(),
1426        "no-shuffle should have equal upstream downstream actor count: {:?} {:?}",
1427        source_fragment_actors,
1428        target_fragment_actors
1429    );
1430    match source_fragment_distribution {
1431        DistributionType::Single => {
1432            let assert_singleton = |bitmap: &Option<Bitmap>| {
1433                assert!(
1434                    bitmap.as_ref().map(|bitmap| bitmap.all()).unwrap_or(true),
1435                    "not singleton: {:?}",
1436                    bitmap
1437                );
1438            };
1439            assert_eq!(
1440                source_fragment_actors.len(),
1441                1,
1442                "singleton distribution actor count not 1: {:?}",
1443                source_fragment_distribution
1444            );
1445            assert_eq!(
1446                target_fragment_actors.len(),
1447                1,
1448                "singleton distribution actor count not 1: {:?}",
1449                target_fragment_distribution
1450            );
1451            let (source_actor_id, bitmap) = source_fragment_actors.iter().next().unwrap();
1452            assert_singleton(bitmap);
1453            let (target_actor_id, bitmap) = target_fragment_actors.iter().next().unwrap();
1454            assert_singleton(bitmap);
1455            vec![(*source_actor_id, *target_actor_id)]
1456        }
1457        DistributionType::Hash => {
1458            let mut target_fragment_actor_index: HashMap<_, _> = target_fragment_actors
1459                .iter()
1460                .map(|(actor_id, bitmap)| {
1461                    let bitmap = bitmap
1462                        .as_ref()
1463                        .expect("hash distribution should have bitmap");
1464                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1465                    (first_vnode, (*actor_id, bitmap))
1466                })
1467                .collect();
1468            source_fragment_actors
1469                .iter()
1470                .map(|(source_actor_id, bitmap)| {
1471                    let bitmap = bitmap
1472                        .as_ref()
1473                        .expect("hash distribution should have bitmap");
1474                    let first_vnode = bitmap.iter_vnodes().next().expect("non-empty bitmap");
1475                    let (target_actor_id, target_bitmap) =
1476                        target_fragment_actor_index.remove(&first_vnode).unwrap_or_else(|| {
1477                            panic!(
1478                                "cannot find matched target actor: {} {:?} {:?} {:?}",
1479                                source_actor_id,
1480                                first_vnode,
1481                                source_fragment_actors,
1482                                target_fragment_actors
1483                            );
1484                        });
1485                    assert_eq!(
1486                        bitmap,
1487                        target_bitmap,
1488                        "cannot find matched target actor due to bitmap mismatch: {} {:?} {:?} {:?}",
1489                        source_actor_id,
1490                        first_vnode,
1491                        source_fragment_actors,
1492                        target_fragment_actors
1493                    );
1494                    (*source_actor_id, target_actor_id)
1495                }).collect()
1496        }
1497    }
1498}
1499
1500/// `get_fragment_mappings` returns the fragment vnode mappings of the given job.
1501pub async fn get_fragment_mappings<C>(
1502    db: &C,
1503    job_id: ObjectId,
1504) -> MetaResult<Vec<PbFragmentWorkerSlotMapping>>
1505where
1506    C: ConnectionTrait,
1507{
1508    let job_actors: Vec<(
1509        FragmentId,
1510        DistributionType,
1511        ActorId,
1512        Option<VnodeBitmap>,
1513        WorkerId,
1514        ActorStatus,
1515    )> = Actor::find()
1516        .select_only()
1517        .columns([
1518            fragment::Column::FragmentId,
1519            fragment::Column::DistributionType,
1520        ])
1521        .columns([
1522            actor::Column::ActorId,
1523            actor::Column::VnodeBitmap,
1524            actor::Column::WorkerId,
1525            actor::Column::Status,
1526        ])
1527        .join(JoinType::InnerJoin, actor::Relation::Fragment.def())
1528        .filter(fragment::Column::JobId.eq(job_id))
1529        .into_tuple()
1530        .all(db)
1531        .await?;
1532
1533    Ok(rebuild_fragment_mapping_from_actors(job_actors))
1534}
1535
1536pub fn rebuild_fragment_mapping(fragment: &SharedFragmentInfo) -> PbFragmentWorkerSlotMapping {
1537    let fragment_worker_slot_mapping = match fragment.distribution_type {
1538        DistributionType::Single => {
1539            let actor = fragment.actors.values().exactly_one().unwrap();
1540            WorkerSlotMapping::new_single(WorkerSlotId::new(actor.worker_id as _, 0))
1541        }
1542        DistributionType::Hash => {
1543            let actor_bitmaps: HashMap<_, _> = fragment
1544                .actors
1545                .iter()
1546                .map(|(actor_id, actor_info)| {
1547                    let vnode_bitmap = actor_info
1548                        .vnode_bitmap
1549                        .as_ref()
1550                        .cloned()
1551                        .expect("actor bitmap shouldn't be none in hash fragment");
1552
1553                    (*actor_id as hash::ActorId, vnode_bitmap)
1554                })
1555                .collect();
1556
1557            let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
1558
1559            let actor_locations = fragment
1560                .actors
1561                .iter()
1562                .map(|(actor_id, actor_info)| {
1563                    (*actor_id as hash::ActorId, actor_info.worker_id as u32)
1564                })
1565                .collect();
1566
1567            actor_mapping.to_worker_slot(&actor_locations)
1568        }
1569    };
1570
1571    PbFragmentWorkerSlotMapping {
1572        fragment_id: fragment.fragment_id,
1573        mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1574    }
1575}
1576
1577pub fn rebuild_fragment_mapping_from_actors(
1578    job_actors: Vec<(
1579        FragmentId,
1580        DistributionType,
1581        ActorId,
1582        Option<VnodeBitmap>,
1583        WorkerId,
1584        ActorStatus,
1585    )>,
1586) -> Vec<FragmentWorkerSlotMapping> {
1587    let mut all_actor_locations = HashMap::new();
1588    let mut actor_bitmaps = HashMap::new();
1589    let mut fragment_actors = HashMap::new();
1590    let mut fragment_dist = HashMap::new();
1591
1592    for (fragment_id, dist, actor_id, bitmap, worker_id, actor_status) in job_actors {
1593        if actor_status == ActorStatus::Inactive {
1594            continue;
1595        }
1596
1597        all_actor_locations
1598            .entry(fragment_id)
1599            .or_insert(HashMap::new())
1600            .insert(actor_id as hash::ActorId, worker_id as u32);
1601        actor_bitmaps.insert(actor_id, bitmap);
1602        fragment_actors
1603            .entry(fragment_id)
1604            .or_insert_with(Vec::new)
1605            .push(actor_id);
1606        fragment_dist.insert(fragment_id, dist);
1607    }
1608
1609    let mut result = vec![];
1610    for (fragment_id, dist) in fragment_dist {
1611        let mut actor_locations = all_actor_locations.remove(&fragment_id).unwrap();
1612        let fragment_worker_slot_mapping = match dist {
1613            DistributionType::Single => {
1614                let actor = fragment_actors
1615                    .remove(&fragment_id)
1616                    .unwrap()
1617                    .into_iter()
1618                    .exactly_one()
1619                    .unwrap() as hash::ActorId;
1620                let actor_location = actor_locations.remove(&actor).unwrap();
1621
1622                WorkerSlotMapping::new_single(WorkerSlotId::new(actor_location, 0))
1623            }
1624            DistributionType::Hash => {
1625                let actors = fragment_actors.remove(&fragment_id).unwrap();
1626
1627                let all_actor_bitmaps: HashMap<_, _> = actors
1628                    .iter()
1629                    .map(|actor_id| {
1630                        let vnode_bitmap = actor_bitmaps
1631                            .remove(actor_id)
1632                            .flatten()
1633                            .expect("actor bitmap shouldn't be none in hash fragment");
1634
1635                        let bitmap = Bitmap::from(&vnode_bitmap.to_protobuf());
1636                        (*actor_id as hash::ActorId, bitmap)
1637                    })
1638                    .collect();
1639
1640                let actor_mapping = ActorMapping::from_bitmaps(&all_actor_bitmaps);
1641
1642                actor_mapping.to_worker_slot(&actor_locations)
1643            }
1644        };
1645
1646        result.push(PbFragmentWorkerSlotMapping {
1647            fragment_id: fragment_id as u32,
1648            mapping: Some(fragment_worker_slot_mapping.to_protobuf()),
1649        })
1650    }
1651    result
1652}
1653
1654/// `get_fragment_actor_ids` returns the fragment actor ids of the given fragments.
1655pub async fn get_fragment_actor_ids<C>(
1656    db: &C,
1657    fragment_ids: Vec<FragmentId>,
1658) -> MetaResult<HashMap<FragmentId, Vec<ActorId>>>
1659where
1660    C: ConnectionTrait,
1661{
1662    let fragment_actors: Vec<(FragmentId, ActorId)> = Actor::find()
1663        .select_only()
1664        .columns([actor::Column::FragmentId, actor::Column::ActorId])
1665        .filter(actor::Column::FragmentId.is_in(fragment_ids))
1666        .into_tuple()
1667        .all(db)
1668        .await?;
1669
1670    Ok(fragment_actors.into_iter().into_group_map())
1671}
1672
1673/// For the given streaming jobs, returns
1674/// - All source fragments
1675/// - All actors
1676/// - All fragments
1677pub async fn get_fragments_for_jobs<C>(
1678    db: &C,
1679    streaming_jobs: Vec<ObjectId>,
1680) -> MetaResult<(
1681    HashMap<SourceId, BTreeSet<FragmentId>>,
1682    HashSet<ActorId>,
1683    HashSet<FragmentId>,
1684)>
1685where
1686    C: ConnectionTrait,
1687{
1688    if streaming_jobs.is_empty() {
1689        return Ok((HashMap::default(), HashSet::default(), HashSet::default()));
1690    }
1691
1692    let fragments: Vec<(FragmentId, i32, StreamNode)> = Fragment::find()
1693        .select_only()
1694        .columns([
1695            fragment::Column::FragmentId,
1696            fragment::Column::FragmentTypeMask,
1697            fragment::Column::StreamNode,
1698        ])
1699        .filter(fragment::Column::JobId.is_in(streaming_jobs))
1700        .into_tuple()
1701        .all(db)
1702        .await?;
1703    let actors: Vec<ActorId> = Actor::find()
1704        .select_only()
1705        .column(actor::Column::ActorId)
1706        .filter(
1707            actor::Column::FragmentId.is_in(fragments.iter().map(|(id, _, _)| *id).collect_vec()),
1708        )
1709        .into_tuple()
1710        .all(db)
1711        .await?;
1712
1713    let fragment_ids = fragments
1714        .iter()
1715        .map(|(fragment_id, _, _)| *fragment_id)
1716        .collect();
1717
1718    let mut source_fragment_ids: HashMap<SourceId, BTreeSet<FragmentId>> = HashMap::new();
1719    for (fragment_id, mask, stream_node) in fragments {
1720        if !FragmentTypeMask::from(mask).contains(FragmentTypeFlag::Source) {
1721            continue;
1722        }
1723        if let Some(source_id) = stream_node.to_protobuf().find_stream_source() {
1724            source_fragment_ids
1725                .entry(source_id as _)
1726                .or_default()
1727                .insert(fragment_id);
1728        }
1729    }
1730
1731    Ok((
1732        source_fragment_ids,
1733        actors.into_iter().collect(),
1734        fragment_ids,
1735    ))
1736}
1737
1738/// Build a object group for notifying the deletion of the given objects.
1739///
1740/// Note that only id fields are filled in the object info, as the arguments are partial objects.
1741/// As a result, the returned notification info should only be used for deletion.
1742pub(crate) fn build_object_group_for_delete(
1743    partial_objects: Vec<PartialObject>,
1744) -> NotificationInfo {
1745    let mut objects = vec![];
1746    for obj in partial_objects {
1747        match obj.obj_type {
1748            ObjectType::Database => objects.push(PbObject {
1749                object_info: Some(PbObjectInfo::Database(PbDatabase {
1750                    id: obj.oid as _,
1751                    ..Default::default()
1752                })),
1753            }),
1754            ObjectType::Schema => objects.push(PbObject {
1755                object_info: Some(PbObjectInfo::Schema(PbSchema {
1756                    id: obj.oid as _,
1757                    database_id: obj.database_id.unwrap() as _,
1758                    ..Default::default()
1759                })),
1760            }),
1761            ObjectType::Table => objects.push(PbObject {
1762                object_info: Some(PbObjectInfo::Table(PbTable {
1763                    id: obj.oid as _,
1764                    schema_id: obj.schema_id.unwrap() as _,
1765                    database_id: obj.database_id.unwrap() as _,
1766                    ..Default::default()
1767                })),
1768            }),
1769            ObjectType::Source => objects.push(PbObject {
1770                object_info: Some(PbObjectInfo::Source(PbSource {
1771                    id: obj.oid as _,
1772                    schema_id: obj.schema_id.unwrap() as _,
1773                    database_id: obj.database_id.unwrap() as _,
1774                    ..Default::default()
1775                })),
1776            }),
1777            ObjectType::Sink => objects.push(PbObject {
1778                object_info: Some(PbObjectInfo::Sink(PbSink {
1779                    id: obj.oid as _,
1780                    schema_id: obj.schema_id.unwrap() as _,
1781                    database_id: obj.database_id.unwrap() as _,
1782                    ..Default::default()
1783                })),
1784            }),
1785            ObjectType::Subscription => objects.push(PbObject {
1786                object_info: Some(PbObjectInfo::Subscription(PbSubscription {
1787                    id: obj.oid as _,
1788                    schema_id: obj.schema_id.unwrap() as _,
1789                    database_id: obj.database_id.unwrap() as _,
1790                    ..Default::default()
1791                })),
1792            }),
1793            ObjectType::View => objects.push(PbObject {
1794                object_info: Some(PbObjectInfo::View(PbView {
1795                    id: obj.oid as _,
1796                    schema_id: obj.schema_id.unwrap() as _,
1797                    database_id: obj.database_id.unwrap() as _,
1798                    ..Default::default()
1799                })),
1800            }),
1801            ObjectType::Index => {
1802                objects.push(PbObject {
1803                    object_info: Some(PbObjectInfo::Index(PbIndex {
1804                        id: obj.oid as _,
1805                        schema_id: obj.schema_id.unwrap() as _,
1806                        database_id: obj.database_id.unwrap() as _,
1807                        ..Default::default()
1808                    })),
1809                });
1810                objects.push(PbObject {
1811                    object_info: Some(PbObjectInfo::Table(PbTable {
1812                        id: obj.oid as _,
1813                        schema_id: obj.schema_id.unwrap() as _,
1814                        database_id: obj.database_id.unwrap() as _,
1815                        ..Default::default()
1816                    })),
1817                });
1818            }
1819            ObjectType::Function => objects.push(PbObject {
1820                object_info: Some(PbObjectInfo::Function(PbFunction {
1821                    id: obj.oid as _,
1822                    schema_id: obj.schema_id.unwrap() as _,
1823                    database_id: obj.database_id.unwrap() as _,
1824                    ..Default::default()
1825                })),
1826            }),
1827            ObjectType::Connection => objects.push(PbObject {
1828                object_info: Some(PbObjectInfo::Connection(PbConnection {
1829                    id: obj.oid as _,
1830                    schema_id: obj.schema_id.unwrap() as _,
1831                    database_id: obj.database_id.unwrap() as _,
1832                    ..Default::default()
1833                })),
1834            }),
1835            ObjectType::Secret => objects.push(PbObject {
1836                object_info: Some(PbObjectInfo::Secret(PbSecret {
1837                    id: obj.oid as _,
1838                    schema_id: obj.schema_id.unwrap() as _,
1839                    database_id: obj.database_id.unwrap() as _,
1840                    ..Default::default()
1841                })),
1842            }),
1843        }
1844    }
1845    NotificationInfo::ObjectGroup(PbObjectGroup { objects })
1846}
1847
1848pub fn extract_external_table_name_from_definition(table_definition: &str) -> Option<String> {
1849    let [mut definition]: [_; 1] = Parser::parse_sql(table_definition)
1850        .context("unable to parse table definition")
1851        .inspect_err(|e| {
1852            tracing::error!(
1853                target: "auto_schema_change",
1854                error = %e.as_report(),
1855                "failed to parse table definition")
1856        })
1857        .unwrap()
1858        .try_into()
1859        .unwrap();
1860    if let SqlStatement::CreateTable { cdc_table_info, .. } = &mut definition {
1861        cdc_table_info
1862            .clone()
1863            .map(|cdc_table_info| cdc_table_info.external_table_name)
1864    } else {
1865        None
1866    }
1867}
1868
/// `rename_relation` renames the target relation to `object_name` and rewrites its definition
/// to match.
///
/// The updates are issued through `txn`; committing the transaction is left to the caller.
/// Returns the updated catalog objects (for notification) together with the relation's old name.
///
/// - `Table`: a source whose `optional_associated_table_id` is this table is renamed as well; the
///   returned old name is the table's.
/// - `Index`: the index and its index table both take `object_name`; the returned old name is
///   the index table's.
/// - `View`: only the name changes; the stored definition is left untouched.
///
/// # Panics
/// On any `object_type` other than table, source, sink, subscription, view or index, or if the
/// target row (or its related object row) does not exist.
pub async fn rename_relation(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
) -> MetaResult<(Vec<PbObject>, String)> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename;

    let mut to_update_relations = vec![];
    // Renames one relation row: sets `name`, rewrites `definition` (except for views), persists
    // both through `txn`, queues the updated catalog object, and evaluates to the old name.
    macro_rules! rename_relation {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            let obj = obj.unwrap();
            let old_name = relation.name.clone();
            relation.name = object_name.into();
            // A view's stored definition is not rewritten on rename.
            if obj.obj_type != ObjectType::View {
                relation.definition = alter_relation_rename(&relation.definition, object_name);
            }
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                name: Set(object_name.into()),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(ObjectModel(relation, obj).into())),
            });
            old_name
        }};
    }
    // TODO: check whether anything needs to change for shared sources.
    let old_name = match object_type {
        ObjectType::Table => {
            let associated_source_id: Option<SourceId> = Source::find()
                .select_only()
                .column(source::Column::SourceId)
                .filter(source::Column::OptionalAssociatedTableId.eq(object_id))
                .into_tuple()
                .one(txn)
                .await?;
            // Rename the associated source too; only the table's old name is returned.
            if let Some(source_id) = associated_source_id {
                rename_relation!(Source, source, source_id, source_id);
            }
            rename_relation!(Table, table, table_id, object_id)
        }
        ObjectType::Source => rename_relation!(Source, source, source_id, object_id),
        ObjectType::Sink => rename_relation!(Sink, sink, sink_id, object_id),
        ObjectType::Subscription => {
            rename_relation!(Subscription, subscription, subscription_id, object_id)
        }
        ObjectType::View => rename_relation!(View, view, view_id, object_id),
        ObjectType::Index => {
            let (mut index, obj) = Index::find_by_id(object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            index.name = object_name.into();
            let index_table_id = index.index_table_id;
            // The returned old name is taken from the index table.
            let old_name = rename_relation!(Table, table, table_id, index_table_id);

            // The index and its index table always share the same name.
            let active_model = index::ActiveModel {
                index_id: sea_orm::ActiveValue::Set(index.index_id),
                name: sea_orm::ActiveValue::Set(object_name.into()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::Index(ObjectModel(index, obj.unwrap()).into())),
            });
            old_name
        }
        _ => unreachable!("only relation name can be altered."),
    };

    Ok((to_update_relations, old_name))
}
1957
1958pub async fn get_database_resource_group<C>(txn: &C, database_id: ObjectId) -> MetaResult<String>
1959where
1960    C: ConnectionTrait,
1961{
1962    let database_resource_group: Option<String> = Database::find_by_id(database_id)
1963        .select_only()
1964        .column(database::Column::ResourceGroup)
1965        .into_tuple()
1966        .one(txn)
1967        .await?
1968        .ok_or_else(|| MetaError::catalog_id_not_found("database", database_id))?;
1969
1970    Ok(database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned()))
1971}
1972
1973pub async fn get_existing_job_resource_group<C>(
1974    txn: &C,
1975    streaming_job_id: ObjectId,
1976) -> MetaResult<String>
1977where
1978    C: ConnectionTrait,
1979{
1980    let (job_specific_resource_group, database_resource_group): (Option<String>, Option<String>) =
1981        StreamingJob::find_by_id(streaming_job_id)
1982            .select_only()
1983            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
1984            .join(JoinType::InnerJoin, object::Relation::Database2.def())
1985            .column(streaming_job::Column::SpecificResourceGroup)
1986            .column(database::Column::ResourceGroup)
1987            .into_tuple()
1988            .one(txn)
1989            .await?
1990            .ok_or_else(|| MetaError::catalog_id_not_found("streaming job", streaming_job_id))?;
1991
1992    Ok(job_specific_resource_group.unwrap_or_else(|| {
1993        database_resource_group.unwrap_or_else(|| DEFAULT_RESOURCE_GROUP.to_owned())
1994    }))
1995}
1996
1997pub fn filter_workers_by_resource_group(
1998    workers: &HashMap<u32, WorkerNode>,
1999    resource_group: &str,
2000) -> BTreeSet<WorkerId> {
2001    workers
2002        .iter()
2003        .filter(|&(_, worker)| {
2004            worker
2005                .resource_group()
2006                .map(|node_label| node_label.as_str() == resource_group)
2007                .unwrap_or(false)
2008        })
2009        .map(|(id, _)| (*id as WorkerId))
2010        .collect()
2011}
2012
/// `rename_relation_refer` updates the definitions of relations that refer to the renamed
/// relation, rewriting them with `alter_relation_rename_refs(definition, old_name, object_name)`.
///
/// The referring objects come from `get_referring_objects`. For a table, the sinks listed in its
/// `incoming_sinks` column are included as well. A referring index is handled by rewriting the
/// definition of its index table. Updates are issued through `txn`, which is not committed here.
/// Returns the updated catalog objects for notification.
///
/// # Errors
/// - Catalog-not-found if `object_type` is `Table` and the table row does not exist.
/// - An error if a referring object is not a table, sink, subscription, view or index.
///
/// # Panics
/// If a referring object's row (or its related object row) does not exist.
pub async fn rename_relation_refer(
    txn: &DatabaseTransaction,
    object_type: ObjectType,
    object_id: ObjectId,
    object_name: &str,
    old_name: &str,
) -> MetaResult<Vec<PbObject>> {
    use sea_orm::ActiveModelTrait;

    use crate::controller::rename::alter_relation_rename_refs;

    let mut to_update_relations = vec![];
    // Rewrites the references in one relation's definition, persists it through `txn`, and
    // queues the updated catalog object.
    macro_rules! rename_relation_ref {
        ($entity:ident, $table:ident, $identity:ident, $object_id:expr) => {{
            let (mut relation, obj) = $entity::find_by_id($object_id)
                .find_also_related(Object)
                .one(txn)
                .await?
                .unwrap();
            relation.definition =
                alter_relation_rename_refs(&relation.definition, old_name, object_name);
            let active_model = $table::ActiveModel {
                $identity: Set(relation.$identity),
                definition: Set(relation.definition.clone()),
                ..Default::default()
            };
            active_model.update(txn).await?;
            to_update_relations.push(PbObject {
                object_info: Some(PbObjectInfo::$entity(
                    ObjectModel(relation, obj.unwrap()).into(),
                )),
            });
        }};
    }
    let mut objs = get_referring_objects(object_id, txn).await?;
    // Sinks writing into a table are also treated as referring to it.
    if object_type == ObjectType::Table {
        let incoming_sinks: I32Array = Table::find_by_id(object_id)
            .select_only()
            .column(table::Column::IncomingSinks)
            .into_tuple()
            .one(txn)
            .await?
            .ok_or_else(|| MetaError::catalog_id_not_found("table", object_id))?;

        objs.extend(
            incoming_sinks
                .into_inner()
                .into_iter()
                .map(|id| PartialObject {
                    oid: id,
                    obj_type: ObjectType::Sink,
                    schema_id: None,
                    database_id: None,
                }),
        );
    }

    for obj in objs {
        match obj.obj_type {
            ObjectType::Table => rename_relation_ref!(Table, table, table_id, obj.oid),
            ObjectType::Sink => rename_relation_ref!(Sink, sink, sink_id, obj.oid),
            ObjectType::Subscription => {
                rename_relation_ref!(Subscription, subscription, subscription_id, obj.oid)
            }
            ObjectType::View => rename_relation_ref!(View, view, view_id, obj.oid),
            ObjectType::Index => {
                // Only the index table's definition is rewritten; the index row is not touched.
                let index_table_id: Option<TableId> = Index::find_by_id(obj.oid)
                    .select_only()
                    .column(index::Column::IndexTableId)
                    .into_tuple()
                    .one(txn)
                    .await?;
                rename_relation_ref!(Table, table, table_id, index_table_id.unwrap());
            }
            _ => {
                bail!(
                    "only the table, sink, subscription, view and index will depend on other objects."
                )
            }
        }
    }

    Ok(to_update_relations)
}
2099
/// Validate that a subscription can be safely deleted. Deletion is allowed when either of the
/// following holds:
/// 1. The upstream table is not referred to by any object in a different database (e.g. a
///    cross-db mv), not counting the subscription itself.
/// 2. After deleting the subscription, the upstream table still has at least one subscription.
///
/// # Errors
/// - Catalog-not-found if `subscription_id` does not exist.
/// - Permission-denied, reporting the number of cross-db referrers, if neither condition holds.
pub async fn validate_subscription_deletion<C>(txn: &C, subscription_id: ObjectId) -> MetaResult<()>
where
    C: ConnectionTrait,
{
    let upstream_table_id: ObjectId = Subscription::find_by_id(subscription_id)
        .select_only()
        .column(subscription::Column::DependentTableId)
        .into_tuple()
        .one(txn)
        .await?
        .ok_or_else(|| MetaError::catalog_id_not_found("subscription", subscription_id))?;

    // Counts every subscription on the upstream table, including the one being deleted.
    let cnt = Subscription::find()
        .filter(subscription::Column::DependentTableId.eq(upstream_table_id))
        .count(txn)
        .await?;
    if cnt > 1 {
        // At least one other subscription remains on the upstream table
        // after this one is dropped.
        return Ok(());
    }

    // Ensure that the upstream table is not referred to by any cross-db object.
    // NOTE(review): `o1` is presumably the referenced object (`oid`) and `o2` the referrer
    // (`used_by`), going by the alias names; confirm against the `Object1`/`Object2` relations.
    let obj_alias = Alias::new("o1");
    let used_by_alias = Alias::new("o2");
    let count = ObjectDependency::find()
        .join_as(
            JoinType::InnerJoin,
            object_dependency::Relation::Object2.def(),
            obj_alias.clone(),
        )
        .join_as(
            JoinType::InnerJoin,
            object_dependency::Relation::Object1.def(),
            used_by_alias.clone(),
        )
        .filter(
            object_dependency::Column::Oid
                .eq(upstream_table_id)
                .and(object_dependency::Column::UsedBy.ne(subscription_id))
                .and(
                    Expr::col((obj_alias, object::Column::DatabaseId))
                        .ne(Expr::col((used_by_alias, object::Column::DatabaseId))),
                ),
        )
        .count(txn)
        .await?;

    if count != 0 {
        return Err(MetaError::permission_denied(format!(
            "Referenced by {} cross-db objects.",
            count
        )));
    }

    Ok(())
}
2160
#[cfg(test)]
mod tests {
    use super::*;

    /// CDC table definitions yield the quoted external table name.
    #[test]
    fn test_extract_cdc_table_name() {
        let ddl1 = "CREATE TABLE t1 () FROM pg_source TABLE 'public.t1'";
        let ddl2 = "CREATE TABLE t2 (v1 int) FROM pg_source TABLE 'mydb.t2'";
        assert_eq!(
            extract_external_table_name_from_definition(ddl1),
            Some("public.t1".into())
        );
        assert_eq!(
            extract_external_table_name_from_definition(ddl2),
            Some("mydb.t2".into())
        );
    }

    /// Definitions without a CDC `FROM ... TABLE` clause, or that are not `CREATE TABLE`
    /// statements at all, have no external table name.
    #[test]
    fn test_extract_cdc_table_name_absent() {
        assert_eq!(
            extract_external_table_name_from_definition("CREATE TABLE t3 (v1 int)"),
            None
        );
        assert_eq!(
            extract_external_table_name_from_definition("CREATE VIEW v AS SELECT 1"),
            None
        );
    }
}