// risingwave_batch/worker_manager/worker_node_manager.rs

use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, RwLock, RwLockReadGuard};
17use std::time::Duration;
18
19use rand::seq::IndexedRandom;
20use risingwave_common::bail;
21use risingwave_common::catalog::OBJECT_ID_PLACEHOLDER;
22use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
23use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
24use risingwave_pb::common::{WorkerNode, WorkerType};
25
26use crate::error::{BatchError, Result};
27
/// Identifier of a streaming fragment; used as the key of the vnode-mapping tables below.
pub(crate) type FragmentId = u32;
29
/// Tracks the cluster's worker nodes and per-fragment vnode mappings for the
/// batch engine. All state is behind locks, so a shared reference is enough
/// for both readers and writers.
pub struct WorkerNodeManager {
    // Worker list plus the streaming/serving fragment vnode mappings.
    inner: RwLock<WorkerNodeManagerInner>,
    // Ids of workers temporarily excluded from scheduling; entries are
    // inserted by `mask_worker_node` and removed by the timer task it spawns.
    // `Arc` so the spawned task can unmask after the manager reference is gone.
    worker_node_mask: Arc<RwLock<HashSet<u32>>>,
}
36
/// Lock-protected state of [`WorkerNodeManager`].
struct WorkerNodeManagerInner {
    // All registered workers of any `WorkerType`; the list accessors keep
    // only `WorkerType::ComputeNode` entries.
    worker_nodes: Vec<WorkerNode>,
    // fragment id -> worker-slot mapping used for streaming (barrier) reads.
    streaming_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    // fragment id -> worker-slot mapping used for serving (batch) reads.
    serving_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
}
44
/// Shared, cheaply clonable handle to a [`WorkerNodeManager`].
pub type WorkerNodeManagerRef = Arc<WorkerNodeManager>;
46
47impl Default for WorkerNodeManager {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
impl WorkerNodeManager {
    /// Creates an empty manager: no workers, no fragment mappings, empty mask.
    /// State is filled in later via [`Self::add_worker_node`], [`Self::refresh`]
    /// and the mapping setters.
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(WorkerNodeManagerInner {
                worker_nodes: Default::default(),
                streaming_fragment_vnode_mapping: Default::default(),
                serving_fragment_vnode_mapping: Default::default(),
            }),
            worker_node_mask: Arc::new(Default::default()),
        }
    }

    /// Builds a manager pre-populated with `worker_nodes` and empty fragment
    /// mappings. Intended for tests.
    pub fn mock(worker_nodes: Vec<WorkerNode>) -> Self {
        let inner = RwLock::new(WorkerNodeManagerInner {
            worker_nodes,
            streaming_fragment_vnode_mapping: HashMap::new(),
            serving_fragment_vnode_mapping: HashMap::new(),
        });
        Self {
            inner,
            worker_node_mask: Arc::new(Default::default()),
        }
    }

    /// Returns clones of all registered workers whose type is
    /// [`WorkerType::ComputeNode`]; other worker types are filtered out.
    pub fn list_worker_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .iter()
            .filter(|w| w.r#type() == WorkerType::ComputeNode)
            .cloned()
            .collect()
    }

    /// Compute nodes whose property advertises `is_serving`.
    fn list_serving_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_worker_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
            .collect()
    }

    /// Compute nodes whose property advertises `is_streaming`.
    fn list_streaming_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_worker_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_streaming))
            .collect()
    }

    /// Upserts a worker keyed by `node.id`: replaces the stored entry with
    /// the same id, or appends the node if the id is new.
    pub fn add_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        match write_guard
            .worker_nodes
            .iter_mut()
            .find(|w| w.id == node.id)
        {
            None => {
                write_guard.worker_nodes.push(node);
            }
            Some(w) => {
                *w = node;
            }
        }
    }

    /// Removes every stored worker whose id equals `node.id`; a no-op when
    /// the id is unknown.
    pub fn remove_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.retain(|x| x.id != node.id);
    }

    /// Replaces the whole snapshot — worker list plus both streaming and
    /// serving fragment mappings — under a single write lock, so readers
    /// never observe a partially refreshed state.
    pub fn refresh(
        &self,
        nodes: Vec<WorkerNode>,
        streaming_mapping: HashMap<FragmentId, WorkerSlotMapping>,
        serving_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut write_guard = self.inner.write().unwrap();
        tracing::debug!("Refresh worker nodes {:?}.", nodes);
        tracing::debug!(
            "Refresh streaming vnode mapping for fragments {:?}.",
            streaming_mapping.keys()
        );
        tracing::debug!(
            "Refresh serving vnode mapping for fragments {:?}.",
            serving_mapping.keys()
        );
        write_guard.worker_nodes = nodes;
        write_guard.streaming_fragment_vnode_mapping = streaming_mapping;
        write_guard.serving_fragment_vnode_mapping = serving_mapping;
    }

    /// Resolves each worker slot to a clone of its owning worker node, in
    /// input order. The same worker may appear several times when multiple
    /// slots live on it.
    ///
    /// # Errors
    /// - [`BatchError::EmptyWorkerNodes`] when `worker_slot_ids` is empty.
    /// - Bails when a slot references a worker id that is not registered.
    pub fn get_workers_by_worker_slot_ids(
        &self,
        worker_slot_ids: &[WorkerSlotId],
    ) -> Result<Vec<WorkerNode>> {
        if worker_slot_ids.is_empty() {
            return Err(BatchError::EmptyWorkerNodes);
        }

        let guard = self.inner.read().unwrap();

        // Index workers by id once so each slot lookup below is O(1).
        let worker_index: HashMap<_, _> = guard.worker_nodes.iter().map(|w| (w.id, w)).collect();

        let mut workers = Vec::with_capacity(worker_slot_ids.len());

        for worker_slot_id in worker_slot_ids {
            match worker_index.get(&worker_slot_id.worker_id()) {
                Some(worker) => workers.push((*worker).clone()),
                None => bail!(
                    "No worker node found for worker slot id: {}",
                    worker_slot_id
                ),
            }
        }

        Ok(workers)
    }

    /// Returns a clone of the streaming vnode mapping for `fragment_id`,
    /// or [`BatchError::StreamingVnodeMappingNotFound`] if absent.
    pub fn get_streaming_fragment_mapping(
        &self,
        fragment_id: &FragmentId,
    ) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .streaming_fragment_vnode_mapping
            .get(fragment_id)
            .cloned()
            .ok_or_else(|| BatchError::StreamingVnodeMappingNotFound(*fragment_id))
    }

    /// Registers the streaming mapping for a fragment, keeping any existing
    /// entry (`try_insert` fails when the key is already occupied).
    ///
    /// NOTE(review): `try_insert` errs when a previous mapping *already
    /// exists*, yet the log says "not found" — the message looks copied from
    /// `update_streaming_fragment_mapping` (where `insert` returning `None`
    /// does mean "not found"). Confirm the intended wording.
    pub fn insert_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        if self
            .inner
            .write()
            .unwrap()
            .streaming_fragment_vnode_mapping
            .try_insert(fragment_id, vnode_mapping)
            .is_err()
        {
            tracing::info!(
                "Previous batch vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

    /// Overwrites the streaming mapping for a fragment; logs (but tolerates)
    /// the case where no previous mapping existed.
    pub fn update_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        let mut guard = self.inner.write().unwrap();
        if guard
            .streaming_fragment_vnode_mapping
            .insert(fragment_id, vnode_mapping)
            .is_none()
        {
            tracing::info!(
                "Previous vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

    /// Drops the streaming mapping for a fragment. Missing entries are only
    /// warned about for real fragment ids; the placeholder id is expected to
    /// have no mapping and stays silent.
    pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) {
        let mut guard = self.inner.write().unwrap();

        let res = guard.streaming_fragment_vnode_mapping.remove(fragment_id);
        match &res {
            Some(_) => {}
            None if OBJECT_ID_PLACEHOLDER == *fragment_id => {
                // Placeholder fragment id: absence is expected, do not warn.
            }
            None => {
                tracing::warn!(fragment_id, "Streaming vnode mapping not found");
            }
        };
    }

    /// Returns a clone of the serving vnode mapping for `fragment_id`, or
    /// [`BatchError::ServingVnodeMappingNotFound`] if absent.
    fn serving_fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .get_serving_fragment_mapping(fragment_id)
            .ok_or_else(|| BatchError::ServingVnodeMappingNotFound(fragment_id))
    }

    /// Replaces the entire serving mapping table with `mappings`.
    pub fn set_serving_fragment_mapping(&self, mappings: HashMap<FragmentId, WorkerSlotMapping>) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Set serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        guard.serving_fragment_vnode_mapping = mappings;
    }

    /// Inserts or overwrites serving mappings per fragment, leaving
    /// fragments not mentioned in `mappings` untouched.
    pub fn upsert_serving_fragment_mapping(
        &self,
        mappings: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Upsert serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        for (fragment_id, mapping) in mappings {
            guard
                .serving_fragment_vnode_mapping
                .insert(fragment_id, mapping);
        }
    }

    /// Removes the serving mappings of the given fragments; unknown ids are
    /// silently ignored.
    pub fn remove_serving_fragment_mapping(&self, fragment_ids: &[FragmentId]) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Delete serving vnode mapping for fragments {:?}",
            fragment_ids
        );
        for fragment_id in fragment_ids {
            guard.serving_fragment_vnode_mapping.remove(fragment_id);
        }
    }

    /// Read guard over the set of masked worker ids. Callers hold the lock
    /// for as long as they keep the guard, so it should be short-lived.
    fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<u32>> {
        self.worker_node_mask.read().unwrap()
    }

    /// Temporarily masks a worker from scheduling for `duration`, spawning a
    /// timer task that removes the mask afterwards. Must run inside a Tokio
    /// runtime (`tokio::spawn`).
    ///
    /// If the worker is already masked, this returns early — the existing
    /// timer keeps its original deadline; re-masking does not extend it.
    pub fn mask_worker_node(&self, worker_node_id: u32, duration: Duration) {
        tracing::info!(
            "Mask worker node {} for {:?} temporarily",
            worker_node_id,
            duration
        );
        let mut worker_node_mask = self.worker_node_mask.write().unwrap();
        if worker_node_mask.contains(&worker_node_id) {
            return;
        }
        worker_node_mask.insert(worker_node_id);
        // Clone the Arc so the unmask task outlives `self`.
        let worker_node_mask_ref = self.worker_node_mask.clone();
        tokio::spawn(async move {
            tokio::time::sleep(duration).await;
            worker_node_mask_ref
                .write()
                .unwrap()
                .remove(&worker_node_id);
        });
    }
}
311
312impl WorkerNodeManagerInner {
313 fn get_serving_fragment_mapping(&self, fragment_id: FragmentId) -> Option<WorkerSlotMapping> {
314 self.serving_fragment_vnode_mapping
315 .get(&fragment_id)
316 .cloned()
317 }
318}
319
/// Selects worker nodes for batch query scheduling on top of a shared
/// [`WorkerNodeManager`], honoring the barrier-read mode and the temporary
/// worker mask.
#[derive(Clone)]
pub struct WorkerNodeSelector {
    pub manager: WorkerNodeManagerRef,
    // When true, schedule on streaming workers (barrier read) instead of
    // serving workers, and skip mask filtering.
    enable_barrier_read: bool,
}
326
327impl WorkerNodeSelector {
328 pub fn new(manager: WorkerNodeManagerRef, enable_barrier_read: bool) -> Self {
329 Self {
330 manager,
331 enable_barrier_read,
332 }
333 }
334
335 pub fn worker_node_count(&self) -> usize {
336 if self.enable_barrier_read {
337 self.manager.list_streaming_worker_nodes().len()
338 } else {
339 self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
340 .len()
341 }
342 }
343
344 pub fn schedule_unit_count(&self) -> usize {
345 let worker_nodes = if self.enable_barrier_read {
346 self.manager.list_streaming_worker_nodes()
347 } else {
348 self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
349 };
350 worker_nodes
351 .iter()
352 .map(|node| node.compute_node_parallelism())
353 .sum()
354 }
355
356 pub fn fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
357 if self.enable_barrier_read {
358 self.manager.get_streaming_fragment_mapping(&fragment_id)
359 } else {
360 let mapping = (self.manager.serving_fragment_mapping(fragment_id)).or_else(|_| {
361 tracing::warn!(
362 fragment_id,
363 "Serving fragment mapping not found, fall back to streaming one."
364 );
365 self.manager.get_streaming_fragment_mapping(&fragment_id)
366 })?;
367
368 if self.manager.worker_node_mask().is_empty() {
370 Ok(mapping)
371 } else {
372 let workers = self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
373 let max_parallelism = mapping.to_single().map(|_| 1);
375 let masked_mapping =
376 place_vnode(Some(&mapping), &workers, max_parallelism, mapping.len())
377 .ok_or_else(|| BatchError::EmptyWorkerNodes)?;
378 Ok(masked_mapping)
379 }
380 }
381 }
382
383 pub fn next_random_worker(&self) -> Result<WorkerNode> {
384 let worker_nodes = if self.enable_barrier_read {
385 self.manager.list_streaming_worker_nodes()
386 } else {
387 self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
388 };
389 worker_nodes
390 .choose(&mut rand::rng())
391 .ok_or_else(|| BatchError::EmptyWorkerNodes)
392 .map(|w| (*w).clone())
393 }
394
395 fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
396 let mask = self.manager.worker_node_mask();
397 if origin.iter().all(|w| mask.contains(&w.id)) {
398 return origin;
399 }
400 origin
401 .into_iter()
402 .filter(|w| !mask.contains(&w.id))
403 .collect()
404 }
405}
406
#[cfg(test)]
mod tests {

    use risingwave_common::util::addr::HostAddr;
    use risingwave_pb::common::worker_node;
    use risingwave_pb::common::worker_node::Property;

    /// Exercises add/remove/list on `WorkerNodeManager` with two compute
    /// nodes: worker 1 is serving + streaming, worker 2 is serving-only.
    #[test]
    fn test_worker_node_manager() {
        use super::*;

        // An empty manager lists nothing.
        let manager = WorkerNodeManager::mock(vec![]);
        assert_eq!(manager.list_serving_worker_nodes().len(), 0);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_worker_nodes(), vec![]);

        let worker_nodes = vec![
            // Serving + streaming compute node.
            WorkerNode {
                id: 1,
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1234").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: true,
                    ..Default::default()
                }),
                transactional_id: Some(1),
                ..Default::default()
            },
            // Serving-only compute node.
            WorkerNode {
                id: 2,
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1235").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: false,
                    ..Default::default()
                }),
                transactional_id: Some(2),
                ..Default::default()
            },
        ];
        worker_nodes
            .iter()
            .for_each(|w| manager.add_worker_node(w.clone()));
        // Both workers serve; only worker 1 streams.
        assert_eq!(manager.list_serving_worker_nodes().len(), 2);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 1);
        assert_eq!(manager.list_worker_nodes(), worker_nodes);

        // Removing worker 1 leaves only the serving-only node.
        manager.remove_worker_node(worker_nodes[0].clone());
        assert_eq!(manager.list_serving_worker_nodes().len(), 1);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(
            manager.list_worker_nodes(),
            worker_nodes.as_slice()[1..].to_vec()
        );
    }
}