risingwave_stream/task/
mod.rs1use std::collections::{HashMap, HashSet};
16
17use anyhow::anyhow;
18use parking_lot::{MappedMutexGuard, Mutex, MutexGuard, RwLock};
19use risingwave_common::config::StreamingConfig;
20use risingwave_common::util::addr::HostAddr;
21use risingwave_pb::common::ActorInfo;
22use risingwave_rpc_client::ComputeClientPoolRef;
23
24use crate::error::StreamResult;
25use crate::executor::exchange::permit::{self, Receiver, Sender};
26
27mod barrier_manager;
28mod env;
29mod stream_manager;
30
31pub use barrier_manager::*;
32pub use env::*;
33use risingwave_common::catalog::DatabaseId;
34pub use stream_manager::*;
35
36pub type ConsumableChannelPair = (Option<Sender>, Option<Receiver>);
37pub type ActorId = u32;
38pub type FragmentId = u32;
39pub type DispatcherId = u64;
40pub type UpDownActorIds = (ActorId, ActorId);
41pub type UpDownFragmentIds = (FragmentId, FragmentId);
42
/// Newtype over a raw `u32` that identifies a partial graph.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PartialGraphId(u32);
45
/// Fixed partial graph id used by unit tests (`u32::MAX`).
#[cfg(test)]
pub(crate) const TEST_PARTIAL_GRAPH_ID: PartialGraphId = PartialGraphId(u32::MAX);

/// Fixed database id used by unit tests (`u32::MAX`).
#[cfg(test)]
pub(crate) const TEST_DATABASE_ID: DatabaseId = DatabaseId::new(u32::MAX);
52
53impl PartialGraphId {
54 fn new(id: u32) -> Self {
55 Self(id)
56 }
57}
58
59impl From<PartialGraphId> for u32 {
60 fn from(val: PartialGraphId) -> u32 {
61 val.0
62 }
63}
64
/// State bound to one database (`database_id`). It holds:
/// - the exchange channels between pairs of actors;
/// - the registered [`ActorInfo`]s;
/// - node-level settings copied from the [`StreamEnvironment`] when the context is created.
///
/// NOTE(review): the name suggests one instance is shared by that database's actors; confirm
/// against callers.
pub struct SharedContext {
    /// The database this context was created for.
    pub(crate) database_id: DatabaseId,
    /// Term identifier supplied at construction; a copy is returned by `term_id()`.
    term_id: String,

    /// Exchange channels keyed by `(upstream actor, downstream actor)`.
    ///
    /// An entry is created on first access by either `take_sender` or `take_receiver`. Each
    /// half can be taken once; taking the same half again fails while the entry exists.
    /// `drop_actors` removes entries whose upstream actor is dropped.
    channel_map: Mutex<HashMap<UpDownActorIds, ConsumableChannelPair>>,

    /// Actor infos registered through `add_actors` and read by `get_actor_info`.
    actor_infos: RwLock<HashMap<ActorId, ActorInfo>>,

    /// Server address copied from the environment (`env.server_address()`).
    pub(crate) addr: HostAddr,

    /// Client pool obtained from the environment (`env.client_pool()`).
    pub(crate) compute_client_pool: ComputeClientPoolRef,

    /// Streaming configuration. Its `developer.exchange_*` settings size newly created
    /// exchange channels.
    pub(crate) config: StreamingConfig,
}
111
112impl std::fmt::Debug for SharedContext {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("SharedContext")
115 .field("addr", &self.addr)
116 .finish_non_exhaustive()
117 }
118}
119
120impl SharedContext {
121 pub fn new(database_id: DatabaseId, env: &StreamEnvironment, term_id: String) -> Self {
122 Self {
123 database_id,
124 term_id,
125 channel_map: Default::default(),
126 actor_infos: Default::default(),
127 addr: env.server_address().clone(),
128 config: env.config().as_ref().to_owned(),
129 compute_client_pool: env.client_pool(),
130 }
131 }
132
133 pub fn term_id(&self) -> String {
134 self.term_id.clone()
135 }
136
137 #[cfg(test)]
138 pub fn for_test() -> Self {
139 use std::sync::Arc;
140
141 use risingwave_common::config::StreamingDeveloperConfig;
142 use risingwave_rpc_client::ComputeClientPool;
143
144 Self {
145 database_id: TEST_DATABASE_ID,
146 term_id: "for_test".into(),
147 channel_map: Default::default(),
148 actor_infos: Default::default(),
149 addr: LOCAL_TEST_ADDR.clone(),
150 config: StreamingConfig {
151 developer: StreamingDeveloperConfig {
152 exchange_initial_permits: permit::for_test::INITIAL_PERMITS,
153 exchange_batched_permits: permit::for_test::BATCHED_PERMITS,
154 exchange_concurrent_barriers: permit::for_test::CONCURRENT_BARRIERS,
155 ..Default::default()
156 },
157 ..Default::default()
158 },
159 compute_client_pool: Arc::new(ComputeClientPool::for_test()),
160 }
161 }
162
163 fn get_or_insert_channels(
166 &self,
167 ids: UpDownActorIds,
168 ) -> MappedMutexGuard<'_, ConsumableChannelPair> {
169 MutexGuard::map(self.channel_map.lock(), |map| {
170 map.entry(ids).or_insert_with(|| {
171 let (tx, rx) = permit::channel(
172 self.config.developer.exchange_initial_permits,
173 self.config.developer.exchange_batched_permits,
174 self.config.developer.exchange_concurrent_barriers,
175 );
176 (Some(tx), Some(rx))
177 })
178 })
179 }
180
181 pub fn take_sender(&self, ids: &UpDownActorIds) -> StreamResult<Sender> {
182 self.get_or_insert_channels(*ids)
183 .0
184 .take()
185 .ok_or_else(|| anyhow!("sender for {ids:?} has already been taken").into())
186 }
187
188 pub fn take_receiver(&self, ids: UpDownActorIds) -> StreamResult<Receiver> {
189 self.get_or_insert_channels(ids)
190 .1
191 .take()
192 .ok_or_else(|| anyhow!("receiver for {ids:?} has already been taken").into())
193 }
194
195 pub fn get_actor_info(&self, actor_id: &ActorId) -> StreamResult<ActorInfo> {
196 self.actor_infos
197 .read()
198 .get(actor_id)
199 .cloned()
200 .ok_or_else(|| anyhow!("actor {} not found in info table", actor_id).into())
201 }
202
203 pub fn config(&self) -> &StreamingConfig {
204 &self.config
205 }
206
207 pub(super) fn drop_actors(&self, actors: &HashSet<ActorId>) {
208 self.channel_map
209 .lock()
210 .retain(|(up_id, _), _| !actors.contains(up_id));
211 let mut actor_infos = self.actor_infos.write();
212 for actor_id in actors {
213 actor_infos.remove(actor_id);
214 }
215 }
216
217 pub(crate) fn add_actors(&self, new_actor_infos: impl Iterator<Item = ActorInfo>) {
218 let mut actor_infos = self.actor_infos.write();
219 for actor in new_actor_infos {
220 if let Some(prev_actor) = actor_infos.get(&actor.actor_id) {
221 if cfg!(debug_assertions) {
222 panic!("duplicate actor info: {:?} {:?}", actor, actor_infos);
223 }
224 if prev_actor != &actor {
225 warn!(
226 ?prev_actor,
227 ?actor,
228 "add actor again but have different actor info. ignored"
229 );
230 }
231 } else {
232 actor_infos.insert(actor.actor_id, actor);
233 }
234 }
235 }
236}