risingwave_frontend/stream_fragmenter/graph/
fragment_graph.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::rc::Rc;
17
18use risingwave_pb::stream_plan::stream_fragment_graph::{
19    StreamFragment as StreamFragmentProto, StreamFragmentEdge as StreamFragmentEdgeProto,
20};
21use risingwave_pb::stream_plan::{
22    DispatchStrategy, FragmentTypeFlag, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode,
23};
24use thiserror_ext::AsReport;
25
/// Identifier of a fragment inside a [`StreamFragmentGraph`]. It keys the graph's fragment map
/// and, as an `(upstream, downstream)` pair, its edge map.
pub type LocalFragmentId = u32;
27
/// [`StreamFragment`] represents a fragment node in the fragment DAG.
#[derive(Clone, Debug)]
pub struct StreamFragment {
    /// The allocated fragment id.
    pub fragment_id: LocalFragmentId,

    /// Root stream node of this fragment. [`StreamFragment::new`] leaves it as `None`.
    pub node: Option<Box<StreamNode>>,

    /// Bitwise-OR of the type flags of this fragment. [`StreamFragment::new`] initializes it to
    /// `FragmentTypeFlag::FragmentUnspecified`.
    pub fragment_type_mask: u32,

    /// Marks whether this fragment requires exactly one actor.
    pub requires_singleton: bool,

    /// Number of table ids (stateful states) for this fragment.
    pub table_ids_cnt: u32,

    /// The upstream table ids of this fragment.
    pub upstream_table_ids: Vec<u32>,
}
49
/// An edge between two nodes in the fragment graph.
///
/// The endpoints are not stored here; they are supplied separately when the edge is added to a
/// [`StreamFragmentGraph`].
#[derive(Debug, Clone)]
pub struct StreamFragmentEdge {
    /// Dispatch strategy for the fragment.
    pub dispatch_strategy: DispatchStrategy,

    /// A unique identifier of this edge. Generally it should be the exchange node's operator
    /// id. When rewriting fragments into delta joins or when inserting 1-to-1 exchanges,
    /// virtual links are generated.
    pub link_id: u64,
}
61
62impl StreamFragment {
63    pub fn new(fragment_id: LocalFragmentId) -> Self {
64        Self {
65            fragment_id,
66            fragment_type_mask: FragmentTypeFlag::FragmentUnspecified as u32,
67            requires_singleton: false,
68            node: None,
69            table_ids_cnt: 0,
70            upstream_table_ids: vec![],
71        }
72    }
73
74    pub fn to_protobuf(&self) -> StreamFragmentProto {
75        StreamFragmentProto {
76            fragment_id: self.fragment_id,
77            node: self.node.clone().map(|n| *n),
78            fragment_type_mask: self.fragment_type_mask,
79            requires_singleton: self.requires_singleton,
80            table_ids_cnt: self.table_ids_cnt,
81            upstream_table_ids: self.upstream_table_ids.clone(),
82        }
83    }
84}
85
/// [`StreamFragmentGraph`] stores a fragment graph (DAG): its fragments and the edges that
/// connect them.
#[derive(Default)]
pub struct StreamFragmentGraph {
    /// All the fragments in the graph, keyed by fragment id.
    fragments: HashMap<LocalFragmentId, Rc<StreamFragment>>,

    /// Edges between fragments: `(upstream, downstream)` => edge. Because the pair is the map
    /// key, at most one edge can be stored per ordered pair of fragments.
    edges: HashMap<(LocalFragmentId, LocalFragmentId), StreamFragmentEdgeProto>,
}
95
96impl StreamFragmentGraph {
97    pub fn to_protobuf(&self) -> StreamFragmentGraphProto {
98        StreamFragmentGraphProto {
99            fragments: self
100                .fragments
101                .iter()
102                .map(|(k, v)| (*k, v.to_protobuf()))
103                .collect(),
104            edges: self.edges.values().cloned().collect(),
105
106            // Following fields will be filled later in `build_graph` based on session context.
107            ctx: None,
108            dependent_table_ids: vec![],
109            table_ids_cnt: 0,
110            parallelism: None,
111            max_parallelism: 0,
112        }
113    }
114
115    /// Adds a fragment to the graph.
116    pub fn add_fragment(&mut self, stream_fragment: Rc<StreamFragment>) {
117        let id = stream_fragment.fragment_id;
118        let ret = self.fragments.insert(id, stream_fragment);
119        assert!(ret.is_none(), "fragment already exists: {:?}", id);
120    }
121
122    pub fn get_fragment(&self, fragment_id: &LocalFragmentId) -> Option<&Rc<StreamFragment>> {
123        self.fragments.get(fragment_id)
124    }
125
126    /// Links upstream to downstream in the graph.
127    pub fn add_edge(
128        &mut self,
129        upstream_id: LocalFragmentId,
130        downstream_id: LocalFragmentId,
131        edge: StreamFragmentEdge,
132    ) {
133        self.try_add_edge(upstream_id, downstream_id, edge).unwrap();
134    }
135
136    /// Try to link upstream to downstream in the graph.
137    ///
138    /// If the edge between upstream and downstream already exists, return an error.
139    pub fn try_add_edge(
140        &mut self,
141        upstream_id: LocalFragmentId,
142        downstream_id: LocalFragmentId,
143        edge: StreamFragmentEdge,
144    ) -> Result<(), String> {
145        let edge = StreamFragmentEdgeProto {
146            upstream_id,
147            downstream_id,
148            dispatch_strategy: Some(edge.dispatch_strategy),
149            link_id: edge.link_id,
150        };
151
152        self.edges
153            .try_insert((upstream_id, downstream_id), edge)
154            .map(|_| ())
155            .map_err(|e| {
156                format!(
157                    "edge between {} and {} already exists: {}",
158                    upstream_id,
159                    downstream_id,
160                    e.to_report_string()
161                )
162            })
163    }
164}