risingwave_common/catalog/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod column;
16mod external_table;
17mod internal_table;
18mod physical_table;
19mod schema;
20pub mod test_utils;
21
22use std::fmt::Binary;
23use std::sync::Arc;
24
25pub use column::*;
26pub use external_table::*;
27use futures::stream::BoxStream;
28pub use internal_table::*;
29use parse_display::Display;
30pub use physical_table::*;
31use risingwave_pb::catalog::table::PbEngine;
32use risingwave_pb::catalog::{
33    CreateType as PbCreateType, HandleConflictBehavior as PbHandleConflictBehavior,
34    StreamJobStatus as PbStreamJobStatus,
35};
36use risingwave_pb::plan_common::ColumnDescVersion;
37pub use schema::{Field, FieldDisplay, FieldLike, Schema, test_utils as schema_test_utils};
38use serde::{Deserialize, Serialize};
39
40use crate::array::DataChunk;
41pub use crate::constants::hummock;
42use crate::error::BoxedError;
43
/// The global version of the catalog.
pub type CatalogVersion = u64;

/// The version number of the per-table catalog.
pub type TableVersionId = u64;
/// The default version ID for a new table.
///
/// Typed with [`TableVersionId`] so the constant follows the alias if it ever changes.
pub const INITIAL_TABLE_VERSION_ID: TableVersionId = 0;
/// The version number of the per-source catalog.
pub type SourceVersionId = u64;
/// The default version ID for a new source.
///
/// Typed with [`SourceVersionId`] so the constant follows the alias if it ever changes.
pub const INITIAL_SOURCE_VERSION_ID: SourceVersionId = 0;
55
/// Name of the default database.
pub const DEFAULT_DATABASE_NAME: &str = "dev";
/// Name of the default schema.
pub const DEFAULT_SCHEMA_NAME: &str = "public";
/// Name of the `pg_catalog` schema; listed in `SYSTEM_SCHEMAS`.
pub const PG_CATALOG_SCHEMA_NAME: &str = "pg_catalog";
/// Name of the `information_schema` schema; listed in `SYSTEM_SCHEMAS`.
pub const INFORMATION_SCHEMA_SCHEMA_NAME: &str = "information_schema";
/// Name of the `rw_catalog` schema; listed in `SYSTEM_SCHEMAS`.
pub const RW_CATALOG_SCHEMA_NAME: &str = "rw_catalog";
/// Schema-name prefix treated as reserved.
// NOTE(review): presumably user schemas may not start with this prefix — enforcement is not in this file.
pub const RESERVED_PG_SCHEMA_PREFIX: &str = "pg_";
/// Name of the default superuser.
pub const DEFAULT_SUPER_USER: &str = "root";
/// User ID paired with [`DEFAULT_SUPER_USER`].
pub const DEFAULT_SUPER_USER_ID: u32 = 1;
/// Name of the superuser kept for compatibility with customized utils for PostgreSQL.
pub const DEFAULT_SUPER_USER_FOR_PG: &str = "postgres";
/// User ID paired with [`DEFAULT_SUPER_USER_FOR_PG`].
pub const DEFAULT_SUPER_USER_FOR_PG_ID: u32 = 2;

/// Name of the default superuser for admin, which is used only by the cloud control plane.
pub const DEFAULT_SUPER_USER_FOR_ADMIN: &str = "rwadmin";
/// User ID paired with [`DEFAULT_SUPER_USER_FOR_ADMIN`].
pub const DEFAULT_SUPER_USER_FOR_ADMIN_ID: u32 = 3;

// NOTE(review): presumably the first ID handed out to non-reserved users — confirm against the
// ID allocator. Note that this is an `i32`, whereas the user ID constants above are `u32`.
pub const NON_RESERVED_USER_ID: i32 = 11;

/// Size of the ID range set aside below `i32::MAX` (see [`SYS_CATALOG_START_ID`]).
pub const MAX_SYS_CATALOG_NUM: i32 = 5000;
/// Start of the system catalog ID range: `i32::MAX - MAX_SYS_CATALOG_NUM`.
pub const SYS_CATALOG_START_ID: i32 = i32::MAX - MAX_SYS_CATALOG_NUM;

/// Raw ID value (`u32::MAX - 1`) returned by the `placeholder()` constructors of the ID types in
/// this module.
pub const OBJECT_ID_PLACEHOLDER: u32 = u32::MAX - 1;
78
79pub const SYSTEM_SCHEMAS: [&str; 3] = [
80    PG_CATALOG_SCHEMA_NAME,
81    INFORMATION_SCHEMA_SCHEMA_NAME,
82    RW_CATALOG_SCHEMA_NAME,
83];
84pub fn is_system_schema(schema_name: &str) -> bool {
85    SYSTEM_SCHEMAS.contains(&schema_name)
86}
87
/// Returns `true` when `user_name` is exactly [`DEFAULT_SUPER_USER_FOR_ADMIN`].
pub fn is_reserved_admin_user(user_name: &str) -> bool {
    matches!(user_name, DEFAULT_SUPER_USER_FOR_ADMIN)
}
91
92pub const RW_RESERVED_COLUMN_NAME_PREFIX: &str = "_rw_";
93
94/// When there is no primary key specified while creating source, will use
95/// the message key as primary key in `BYTEA` type with this name.
96/// Note: the field has version to track, please refer to [`default_key_column_name_version_mapping`]
97pub const DEFAULT_KEY_COLUMN_NAME: &str = "_rw_key";
98
99pub fn default_key_column_name_version_mapping(version: &ColumnDescVersion) -> &str {
100    match version {
101        ColumnDescVersion::Unspecified => DEFAULT_KEY_COLUMN_NAME,
102        _ => DEFAULT_KEY_COLUMN_NAME,
103    }
104}
105
/// For kafka source, we attach a hidden column [`KAFKA_TIMESTAMP_COLUMN_NAME`] to it, so that we
/// can limit the timestamp range when querying it directly with batch query. The column type is
/// [`crate::types::DataType::Timestamptz`]. For more details, please refer to
/// [this rfc](https://github.com/risingwavelabs/rfcs/pull/20).
pub const KAFKA_TIMESTAMP_COLUMN_NAME: &str = "_rw_kafka_timestamp";

/// The RisingWave iceberg table engine creates the column `_risingwave_iceberg_row_id` in the
/// iceberg table.
///
/// The Iceberg V3 spec uses `_row_id` as a reserved column name for row lineage, so for a table
/// without a primary key we cannot use `_row_id` directly in iceberg and use
/// `_risingwave_iceberg_row_id` instead.
pub const RISINGWAVE_ICEBERG_ROW_ID: &str = "_risingwave_iceberg_row_id";

/// Name of the row ID column.
pub const ROW_ID_COLUMN_NAME: &str = "_row_id";
/// The column ID reserved for the row ID column.
pub const ROW_ID_COLUMN_ID: ColumnId = ColumnId::new(0);

/// The column ID offset for user-defined columns.
///
/// All IDs of user-defined columns must be greater than or equal to this value. It is computed
/// as the ID that follows [`ROW_ID_COLUMN_ID`].
pub const USER_COLUMN_ID_OFFSET: i32 = ROW_ID_COLUMN_ID.next().get_id();

/// Name of the `_rw_timestamp` column.
pub const RW_TIMESTAMP_COLUMN_NAME: &str = "_rw_timestamp";
/// The column ID (`-1`) used for the `_rw_timestamp` column.
pub const RW_TIMESTAMP_COLUMN_ID: ColumnId = ColumnId::new(-1);

/// Name of the iceberg sequence number column.
pub const ICEBERG_SEQUENCE_NUM_COLUMN_NAME: &str = "_iceberg_sequence_number";
/// Name of the iceberg file path column.
pub const ICEBERG_FILE_PATH_COLUMN_NAME: &str = "_iceberg_file_path";
/// Name of the iceberg file position column.
pub const ICEBERG_FILE_POS_COLUMN_NAME: &str = "_iceberg_file_pos";

/// Name of the CDC offset column.
pub const CDC_OFFSET_COLUMN_NAME: &str = "_rw_offset";
/// The number of columns output by the cdc source job.
/// See [`ColumnCatalog::debezium_cdc_source_cols()`] for details.
pub const CDC_SOURCE_COLUMN_NUM: u32 = 3;
/// Name of the CDC table name column.
pub const CDC_TABLE_NAME_COLUMN_NAME: &str = "_rw_table_name";
139
/// The local system catalog reader in the frontend node.
///
/// Implementors must be `Sync + Send + 'static`, so a reader can be shared through
/// [`SysCatalogReaderRef`].
pub trait SysCatalogReader: Sync + Send + 'static {
    /// Reads the data of the system catalog table.
    ///
    /// Returns a boxed stream that borrows from `self` and yields either a [`DataChunk`] or a
    /// [`BoxedError`] per item.
    fn read_table(&self, table_id: TableId) -> BoxStream<'_, Result<DataChunk, BoxedError>>;
}

/// Shared, reference-counted handle to a [`SysCatalogReader`] trait object.
pub type SysCatalogReaderRef = Arc<dyn SysCatalogReader>;

/// Raw object ID.
pub type ObjectId = u32;
149
150#[derive(Clone, Debug, Default, Display, Hash, PartialOrd, PartialEq, Eq, Copy)]
151#[display("{database_id}")]
152pub struct DatabaseId {
153    pub database_id: u32,
154}
155
156impl DatabaseId {
157    pub const fn new(database_id: u32) -> Self {
158        DatabaseId { database_id }
159    }
160
161    pub fn placeholder() -> Self {
162        DatabaseId {
163            database_id: OBJECT_ID_PLACEHOLDER,
164        }
165    }
166}
167
168impl From<u32> for DatabaseId {
169    fn from(id: u32) -> Self {
170        Self::new(id)
171    }
172}
173
174impl From<&u32> for DatabaseId {
175    fn from(id: &u32) -> Self {
176        Self::new(*id)
177    }
178}
179
180impl From<DatabaseId> for u32 {
181    fn from(id: DatabaseId) -> Self {
182        id.database_id
183    }
184}
185
186#[derive(Clone, Debug, Default, Display, Hash, PartialOrd, PartialEq, Eq)]
187#[display("{schema_id}")]
188pub struct SchemaId {
189    pub schema_id: u32,
190}
191
192impl SchemaId {
193    pub fn new(schema_id: u32) -> Self {
194        SchemaId { schema_id }
195    }
196
197    pub fn placeholder() -> Self {
198        SchemaId {
199            schema_id: OBJECT_ID_PLACEHOLDER,
200        }
201    }
202}
203
204impl From<u32> for SchemaId {
205    fn from(id: u32) -> Self {
206        Self::new(id)
207    }
208}
209
210impl From<&u32> for SchemaId {
211    fn from(id: &u32) -> Self {
212        Self::new(*id)
213    }
214}
215
216impl From<SchemaId> for u32 {
217    fn from(id: SchemaId) -> Self {
218        id.schema_id
219    }
220}
221
222#[derive(
223    Clone,
224    Copy,
225    Debug,
226    Display,
227    Default,
228    Hash,
229    PartialOrd,
230    PartialEq,
231    Eq,
232    Ord,
233    Serialize,
234    Deserialize,
235)]
236#[display("{table_id}")]
237pub struct TableId {
238    pub table_id: u32,
239}
240
241impl TableId {
242    pub const fn new(table_id: u32) -> Self {
243        TableId { table_id }
244    }
245
246    /// Sometimes the id field is filled later, we use this value for better debugging.
247    pub const fn placeholder() -> Self {
248        TableId {
249            table_id: OBJECT_ID_PLACEHOLDER,
250        }
251    }
252
253    pub fn table_id(&self) -> u32 {
254        self.table_id
255    }
256
257    pub fn is_placeholder(&self) -> bool {
258        self.table_id == OBJECT_ID_PLACEHOLDER
259    }
260}
261
262impl From<u32> for TableId {
263    fn from(id: u32) -> Self {
264        Self::new(id)
265    }
266}
267
268impl From<&u32> for TableId {
269    fn from(id: &u32) -> Self {
270        Self::new(*id)
271    }
272}
273
274impl From<TableId> for u32 {
275    fn from(id: TableId) -> Self {
276        id.table_id
277    }
278}
279
280#[derive(Clone, Debug, PartialEq, Default, Copy)]
281pub struct TableOption {
282    pub retention_seconds: Option<u32>, // second
283}
284
285impl From<&risingwave_pb::hummock::TableOption> for TableOption {
286    fn from(table_option: &risingwave_pb::hummock::TableOption) -> Self {
287        Self {
288            retention_seconds: table_option.retention_seconds,
289        }
290    }
291}
292
293impl From<&TableOption> for risingwave_pb::hummock::TableOption {
294    fn from(table_option: &TableOption) -> Self {
295        Self {
296            retention_seconds: table_option.retention_seconds,
297        }
298    }
299}
300
301impl TableOption {
302    pub fn new(retention_seconds: Option<u32>) -> Self {
303        // now we only support ttl for TableOption
304        TableOption { retention_seconds }
305    }
306}
307
308#[derive(Clone, Copy, Debug, Display, Default, Hash, PartialOrd, PartialEq, Eq)]
309#[display("{index_id}")]
310pub struct IndexId {
311    pub index_id: u32,
312}
313
314impl IndexId {
315    pub const fn new(index_id: u32) -> Self {
316        IndexId { index_id }
317    }
318
319    /// Sometimes the id field is filled later, we use this value for better debugging.
320    pub const fn placeholder() -> Self {
321        IndexId {
322            index_id: OBJECT_ID_PLACEHOLDER,
323        }
324    }
325
326    pub fn index_id(&self) -> u32 {
327        self.index_id
328    }
329}
330
331impl From<u32> for IndexId {
332    fn from(id: u32) -> Self {
333        Self::new(id)
334    }
335}
336impl From<IndexId> for u32 {
337    fn from(id: IndexId) -> Self {
338        id.index_id
339    }
340}
341
342#[derive(Clone, Copy, Debug, Display, Default, Hash, PartialOrd, PartialEq, Eq, Ord)]
343pub struct FunctionId(pub u32);
344
345impl FunctionId {
346    pub const fn new(id: u32) -> Self {
347        FunctionId(id)
348    }
349
350    pub const fn placeholder() -> Self {
351        FunctionId(OBJECT_ID_PLACEHOLDER)
352    }
353
354    pub fn function_id(&self) -> u32 {
355        self.0
356    }
357}
358
359impl From<u32> for FunctionId {
360    fn from(id: u32) -> Self {
361        Self::new(id)
362    }
363}
364
365impl From<&u32> for FunctionId {
366    fn from(id: &u32) -> Self {
367        Self::new(*id)
368    }
369}
370
371impl From<FunctionId> for u32 {
372    fn from(id: FunctionId) -> Self {
373        id.0
374    }
375}
376
377#[derive(Clone, Copy, Debug, Display, Default, Hash, PartialOrd, PartialEq, Eq, Ord)]
378#[display("{user_id}")]
379pub struct UserId {
380    pub user_id: u32,
381}
382
383impl UserId {
384    pub const fn new(user_id: u32) -> Self {
385        UserId { user_id }
386    }
387
388    pub const fn placeholder() -> Self {
389        UserId {
390            user_id: OBJECT_ID_PLACEHOLDER,
391        }
392    }
393}
394
395impl From<u32> for UserId {
396    fn from(id: u32) -> Self {
397        Self::new(id)
398    }
399}
400
401impl From<&u32> for UserId {
402    fn from(id: &u32) -> Self {
403        Self::new(*id)
404    }
405}
406
407impl From<UserId> for u32 {
408    fn from(id: UserId) -> Self {
409        id.user_id
410    }
411}
412
413#[derive(Clone, Copy, Debug, Display, Default, Hash, PartialOrd, PartialEq, Eq, Ord)]
414pub struct ConnectionId(pub u32);
415
416impl ConnectionId {
417    pub const fn new(id: u32) -> Self {
418        ConnectionId(id)
419    }
420
421    pub const fn placeholder() -> Self {
422        ConnectionId(OBJECT_ID_PLACEHOLDER)
423    }
424
425    pub fn connection_id(&self) -> u32 {
426        self.0
427    }
428}
429
430impl From<u32> for ConnectionId {
431    fn from(id: u32) -> Self {
432        Self::new(id)
433    }
434}
435
436impl From<&u32> for ConnectionId {
437    fn from(id: &u32) -> Self {
438        Self::new(*id)
439    }
440}
441
442impl From<ConnectionId> for u32 {
443    fn from(id: ConnectionId) -> Self {
444        id.0
445    }
446}
447
448#[derive(Clone, Copy, Debug, Display, Default, Hash, PartialOrd, PartialEq, Eq, Ord)]
449pub struct SecretId(pub u32);
450
451impl SecretId {
452    pub const fn new(id: u32) -> Self {
453        SecretId(id)
454    }
455
456    pub const fn placeholder() -> Self {
457        SecretId(OBJECT_ID_PLACEHOLDER)
458    }
459
460    pub fn secret_id(&self) -> u32 {
461        self.0
462    }
463}
464
465impl From<u32> for SecretId {
466    fn from(id: u32) -> Self {
467        Self::new(id)
468    }
469}
470
471impl From<&u32> for SecretId {
472    fn from(id: &u32) -> Self {
473        Self::new(*id)
474    }
475}
476
477impl From<SecretId> for u32 {
478    fn from(id: SecretId) -> Self {
479        id.0
480    }
481}
482
/// In-memory counterpart of the protobuf `HandleConflictBehavior` enum.
///
/// The default is [`ConflictBehavior::NoCheck`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConflictBehavior {
    /// Maps to the protobuf `NoCheck`; the protobuf `Unspecified` also decodes to this variant.
    #[default]
    NoCheck,
    /// Maps to the protobuf `Overwrite`.
    Overwrite,
    /// Maps to the protobuf `Ignore`.
    IgnoreConflict,
    /// Maps to the protobuf `DoUpdateIfNotNull`.
    DoUpdateIfNotNull,
}

/// Expands to an or-pattern matching every [`ConflictBehavior`] variant except `NoCheck`.
///
/// The expansion names `ConflictBehavior` unqualified, so the type must be in scope at the call
/// site. Exported publicly under the name `checked_conflict_behaviors`.
#[macro_export]
macro_rules! _checked_conflict_behaviors {
    () => {
        ConflictBehavior::Overwrite
            | ConflictBehavior::IgnoreConflict
            | ConflictBehavior::DoUpdateIfNotNull
    };
}
pub use _checked_conflict_behaviors as checked_conflict_behaviors;
501
502impl ConflictBehavior {
503    pub fn from_protobuf(tb_conflict_behavior: &PbHandleConflictBehavior) -> Self {
504        match tb_conflict_behavior {
505            PbHandleConflictBehavior::Overwrite => ConflictBehavior::Overwrite,
506            PbHandleConflictBehavior::Ignore => ConflictBehavior::IgnoreConflict,
507            PbHandleConflictBehavior::DoUpdateIfNotNull => ConflictBehavior::DoUpdateIfNotNull,
508            // This is for backward compatibility, in the previous version
509            // `HandleConflictBehavior::Unspecified` represented `NoCheck`, so just treat it as `NoCheck`.
510            PbHandleConflictBehavior::NoCheck | PbHandleConflictBehavior::Unspecified => {
511                ConflictBehavior::NoCheck
512            }
513        }
514    }
515
516    pub fn to_protobuf(self) -> PbHandleConflictBehavior {
517        match self {
518            ConflictBehavior::NoCheck => PbHandleConflictBehavior::NoCheck,
519            ConflictBehavior::Overwrite => PbHandleConflictBehavior::Overwrite,
520            ConflictBehavior::IgnoreConflict => PbHandleConflictBehavior::Ignore,
521            ConflictBehavior::DoUpdateIfNotNull => PbHandleConflictBehavior::DoUpdateIfNotNull,
522        }
523    }
524
525    pub fn debug_to_string(self) -> String {
526        match self {
527            ConflictBehavior::NoCheck => "NoCheck".to_owned(),
528            ConflictBehavior::Overwrite => "Overwrite".to_owned(),
529            ConflictBehavior::IgnoreConflict => "IgnoreConflict".to_owned(),
530            ConflictBehavior::DoUpdateIfNotNull => "DoUpdateIfNotNull".to_owned(),
531        }
532    }
533}
534
535#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
536pub enum Engine {
537    #[default]
538    Hummock,
539    Iceberg,
540}
541
542impl Engine {
543    pub fn from_protobuf(engine: &PbEngine) -> Self {
544        match engine {
545            PbEngine::Hummock | PbEngine::Unspecified => Engine::Hummock,
546            PbEngine::Iceberg => Engine::Iceberg,
547        }
548    }
549
550    pub fn to_protobuf(self) -> PbEngine {
551        match self {
552            Engine::Hummock => PbEngine::Hummock,
553            Engine::Iceberg => PbEngine::Iceberg,
554        }
555    }
556
557    pub fn debug_to_string(self) -> String {
558        match self {
559            Engine::Hummock => "Hummock".to_owned(),
560            Engine::Iceberg => "Iceberg".to_owned(),
561        }
562    }
563}
564
565#[derive(Clone, Copy, Debug, Default, Display, Hash, PartialOrd, PartialEq, Eq, Ord)]
566pub enum StreamJobStatus {
567    #[default]
568    Creating,
569    Created,
570}
571
572impl StreamJobStatus {
573    pub fn from_proto(stream_job_status: PbStreamJobStatus) -> Self {
574        match stream_job_status {
575            PbStreamJobStatus::Creating => StreamJobStatus::Creating,
576            PbStreamJobStatus::Created | PbStreamJobStatus::Unspecified => StreamJobStatus::Created,
577        }
578    }
579
580    pub fn to_proto(self) -> PbStreamJobStatus {
581        match self {
582            StreamJobStatus::Creating => PbStreamJobStatus::Creating,
583            StreamJobStatus::Created => PbStreamJobStatus::Created,
584        }
585    }
586}
587
588#[derive(Clone, Copy, Debug, Display, Hash, PartialOrd, PartialEq, Eq, Ord, Default)]
589pub enum CreateType {
590    #[default]
591    Foreground,
592    Background,
593}
594
595impl CreateType {
596    pub fn from_proto(pb_create_type: PbCreateType) -> Self {
597        match pb_create_type {
598            PbCreateType::Foreground | PbCreateType::Unspecified => CreateType::Foreground,
599            PbCreateType::Background => CreateType::Background,
600        }
601    }
602
603    pub fn to_proto(self) -> PbCreateType {
604        match self {
605            CreateType::Foreground => PbCreateType::Foreground,
606            CreateType::Background => PbCreateType::Background,
607        }
608    }
609}
610
/// A per-database parameter together with its new value.
#[derive(Clone, Debug)]
pub enum AlterDatabaseParam {
    // Barrier related parameters, per database.
    // None represents the default value, which means it follows `SystemParams`.
    /// Barrier interval (the `Ms` suffix suggests milliseconds); `None` follows `SystemParams`.
    BarrierIntervalMs(Option<u32>),
    /// Checkpoint frequency; `None` follows `SystemParams`.
    CheckpointFrequency(Option<u64>),
}
618
// Generates `FragmentTypeFlag`, `FRAGMENT_TYPE_FLAG_LIST`, `TryFrom<u32> for FragmentTypeFlag`
// and `FragmentTypeFlag::as_str_name` from one ordered list of flag names.
//
// The macro has three arms:
// 1. `()`: the entry point. It holds the list of flag names and starts the recursion with an
//    empty list of processed flags and index `0`.
// 2. `({}, {processed..}, $next_index)`: reached once the input list is empty. It emits all
//    generated items from the `{flag, index}` pairs collected so far.
// 3. `({$first, $rest..}, {processed..}, $next_index)`: moves `$first` from the input list to
//    the end of the processed list, paired with `$next_index`, and recurses with
//    `$next_index + 1`.
//
// A flag's bit position is its position in the list in arm 1 (the first flag is bit 0).
// NOTE(review): the `Skipped1` entry looks like a placeholder that keeps later flags on their
// bit positions — presumably new flags must be appended at the end; confirm.
macro_rules! for_all_fragment_type_flags {
    () => {
        for_all_fragment_type_flags! {
            {
                Source,
                Mview,
                Sink,
                Now,
                StreamScan,
                BarrierRecv,
                Values,
                Dml,
                CdcFilter,
                Skipped1,
                SourceScan,
                SnapshotBackfillStreamScan,
                FsFetch,
                CrossDbSnapshotBackfillStreamScan,
                StreamCdcScan,
                VectorIndexWrite,
                UpstreamSinkUnion,
                LocalityProvider
            },
            {},
            0
        }
    };
    // Terminal arm: no flags are left in the input list; generate the items.
    (
        {},
        {
            $(
                {$flag:ident, $index:expr}
            ),*
        },
        $next_index:expr
    ) => {
        // Each flag occupies one bit. `$index` is an expression such as `0 + 1 + 1`; it was
        // captured as an `expr` fragment, so the shift amount is the whole sum.
        #[derive(Clone, Copy, Debug, Display, Hash, PartialOrd, PartialEq, Eq)]
        #[repr(u32)]
        pub enum FragmentTypeFlag {
            $(
                $flag = (1 << $index),
            )*
        }

        // All flags in declaration order; at this point `$next_index` equals the number of flags.
        pub const FRAGMENT_TYPE_FLAG_LIST: [FragmentTypeFlag; $next_index] = [
            $(
                FragmentTypeFlag::$flag,
            )*
        ];

        // Accepts only values equal to exactly one flag's discriminant; `0` and combined
        // masks are rejected with an error message.
        impl TryFrom<u32> for FragmentTypeFlag {
            type Error = String;

            fn try_from(value: u32) -> Result<Self, Self::Error> {
                match value {
                    $(
                        value if value == (FragmentTypeFlag::$flag as u32) => Ok(FragmentTypeFlag::$flag),
                    )*
                    _ => Err(format!("Invalid FragmentTypeFlag value: {}", value)),
                }
            }
        }

        impl FragmentTypeFlag {
            // Returns the flag name in upper snake case, e.g. `StreamScan` -> "STREAM_SCAN".
            pub fn as_str_name(&self) -> &'static str {
                match self {
                    $(
                        FragmentTypeFlag::$flag => paste::paste!{stringify!( [< $flag:snake:upper >] )},
                    )*
                }
            }
        }
    };
    // Recursive arm: assign `$next_index` to `$first` and continue with the remaining flags.
    (
        {$first:ident $(, $rest:ident)*},
        {
            $(
                {$flag:ident, $index:expr}
            ),*
        },
        $next_index:expr
    ) => {
        for_all_fragment_type_flags! {
            {$($rest),*},
            {
                $({$flag, $index},)*
                {$first, $next_index}
            },
            $next_index + 1
        }
    };
}

// Expands to `FragmentTypeFlag` and its companion items at module level.
for_all_fragment_type_flags!();
713
714impl FragmentTypeFlag {
715    pub fn raw_flag(flags: impl IntoIterator<Item = FragmentTypeFlag>) -> u32 {
716        flags.into_iter().fold(0, |acc, flag| acc | (flag as u32))
717    }
718
719    /// Fragments that may be affected by `BACKFILL_RATE_LIMIT`.
720    pub fn backfill_rate_limit_fragments() -> impl Iterator<Item = FragmentTypeFlag> {
721        [FragmentTypeFlag::SourceScan, FragmentTypeFlag::StreamScan].into_iter()
722    }
723
724    /// Fragments that may be affected by `SOURCE_RATE_LIMIT`.
725    /// Note: for `FsFetch`, old fragments don't have this flag set, so don't use this to check.
726    pub fn source_rate_limit_fragments() -> impl Iterator<Item = FragmentTypeFlag> {
727        [FragmentTypeFlag::Source, FragmentTypeFlag::FsFetch].into_iter()
728    }
729
730    /// Fragments that may be affected by `BACKFILL_RATE_LIMIT`.
731    pub fn sink_rate_limit_fragments() -> impl Iterator<Item = FragmentTypeFlag> {
732        [FragmentTypeFlag::Sink].into_iter()
733    }
734
735    /// Note: this doesn't include `FsFetch` created in old versions.
736    pub fn rate_limit_fragments() -> impl Iterator<Item = FragmentTypeFlag> {
737        Self::backfill_rate_limit_fragments()
738            .chain(Self::source_rate_limit_fragments())
739            .chain(Self::sink_rate_limit_fragments())
740    }
741
742    pub fn dml_rate_limit_fragments() -> impl Iterator<Item = FragmentTypeFlag> {
743        [FragmentTypeFlag::Dml].into_iter()
744    }
745}
746
747#[derive(Clone, Copy, Debug, Hash, PartialOrd, PartialEq, Eq, Default)]
748pub struct FragmentTypeMask(u32);
749
750impl Binary for FragmentTypeMask {
751    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
752        write!(f, "{:b}", self.0)
753    }
754}
755
756impl From<i32> for FragmentTypeMask {
757    fn from(value: i32) -> Self {
758        Self(value as u32)
759    }
760}
761
762impl From<u32> for FragmentTypeMask {
763    fn from(value: u32) -> Self {
764        Self(value)
765    }
766}
767
768impl From<FragmentTypeMask> for u32 {
769    fn from(value: FragmentTypeMask) -> Self {
770        value.0
771    }
772}
773
774impl From<FragmentTypeMask> for i32 {
775    fn from(value: FragmentTypeMask) -> Self {
776        value.0 as _
777    }
778}
779
780impl FragmentTypeMask {
781    pub fn empty() -> Self {
782        FragmentTypeMask(0)
783    }
784
785    pub fn add(&mut self, flag: FragmentTypeFlag) {
786        self.0 |= flag as u32;
787    }
788
789    pub fn contains_any(&self, flags: impl IntoIterator<Item = FragmentTypeFlag>) -> bool {
790        let flag = FragmentTypeFlag::raw_flag(flags);
791        (self.0 & flag) != 0
792    }
793
794    pub fn contains(&self, flag: FragmentTypeFlag) -> bool {
795        self.contains_any([flag])
796    }
797}
798
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    // Both items live in this crate, so they are imported through `crate::` rather than through
    // the crate's external name.
    use crate::catalog::{FRAGMENT_TYPE_FLAG_LIST, FragmentTypeFlag};

    /// Pins the name, numeric value and string name of every fragment type flag, and checks
    /// that `TryFrom<u32>` round-trips each flag.
    #[test]
    fn test_all_fragment_type_flag() {
        expect_test::expect![[r#"
            [
                (
                    Source,
                    1,
                    "SOURCE",
                ),
                (
                    Mview,
                    2,
                    "MVIEW",
                ),
                (
                    Sink,
                    4,
                    "SINK",
                ),
                (
                    Now,
                    8,
                    "NOW",
                ),
                (
                    StreamScan,
                    16,
                    "STREAM_SCAN",
                ),
                (
                    BarrierRecv,
                    32,
                    "BARRIER_RECV",
                ),
                (
                    Values,
                    64,
                    "VALUES",
                ),
                (
                    Dml,
                    128,
                    "DML",
                ),
                (
                    CdcFilter,
                    256,
                    "CDC_FILTER",
                ),
                (
                    Skipped1,
                    512,
                    "SKIPPED1",
                ),
                (
                    SourceScan,
                    1024,
                    "SOURCE_SCAN",
                ),
                (
                    SnapshotBackfillStreamScan,
                    2048,
                    "SNAPSHOT_BACKFILL_STREAM_SCAN",
                ),
                (
                    FsFetch,
                    4096,
                    "FS_FETCH",
                ),
                (
                    CrossDbSnapshotBackfillStreamScan,
                    8192,
                    "CROSS_DB_SNAPSHOT_BACKFILL_STREAM_SCAN",
                ),
                (
                    StreamCdcScan,
                    16384,
                    "STREAM_CDC_SCAN",
                ),
                (
                    VectorIndexWrite,
                    32768,
                    "VECTOR_INDEX_WRITE",
                ),
                (
                    UpstreamSinkUnion,
                    65536,
                    "UPSTREAM_SINK_UNION",
                ),
                (
                    LocalityProvider,
                    131072,
                    "LOCALITY_PROVIDER",
                ),
            ]
        "#]]
        .assert_debug_eq(
            &FRAGMENT_TYPE_FLAG_LIST
                .into_iter()
                .map(|flag| (flag, flag as u32, flag.as_str_name()))
                .collect_vec(),
        );
        for flag in FRAGMENT_TYPE_FLAG_LIST {
            assert_eq!(FragmentTypeFlag::try_from(flag as u32).unwrap(), flag);
        }
    }
}