// risingwave_batch/worker_manager/worker_node_manager.rs

use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::time::Duration;

use rand::seq::IndexedRandom;
use risingwave_common::bail;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::id::{FragmentId, WorkerId};
use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
use risingwave_pb::common::{WorkerNode, WorkerType};

use crate::error::{BatchError, Result};

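/// Tracks the live worker nodes and the fragment-to-worker-slot vnode mappings
/// used by the batch engine. `worker_node_mask` temporarily hides workers from
/// the serving candidate set; see [`WorkerNodeManager::mask_worker_node`].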
pub struct WorkerNodeManager {
    inner: RwLock<WorkerNodeManagerInner>,
    worker_node_mask: Arc<RwLock<HashSet<WorkerId>>>,
}

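/// State guarded by the manager's `RwLock`: the live worker node map plus the
/// streaming and serving fragment vnode mappings. The streaming mapping starts
/// as `None` and is populated on the first refresh or insert.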
struct WorkerNodeManagerInner {
    worker_nodes: HashMap<WorkerId, WorkerNode>,
    streaming_fragment_vnode_mapping: Option<HashMap<FragmentId, WorkerSlotMapping>>,
    serving_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
}

pub type WorkerNodeManagerRef = Arc<WorkerNodeManager>;

impl Default for WorkerNodeManager {
    fn default() -> Self {
        Self::new()
    }
}

impl WorkerNodeManager {
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(WorkerNodeManagerInner {
                worker_nodes: Default::default(),
                streaming_fragment_vnode_mapping: None,
                serving_fragment_vnode_mapping: Default::default(),
            }),
            worker_node_mask: Arc::new(Default::default()),
        }
    }

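    /// Construct a manager pre-populated with the given worker nodes; mainly
    /// useful for tests.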
    pub fn mock(worker_nodes: Vec<WorkerNode>) -> Self {
        let worker_nodes = worker_nodes.into_iter().map(|w| (w.id, w)).collect();
        let inner = RwLock::new(WorkerNodeManagerInner {
            worker_nodes,
            streaming_fragment_vnode_mapping: None,
            serving_fragment_vnode_mapping: HashMap::new(),
        });
        Self {
            inner,
            worker_node_mask: Arc::new(Default::default()),
        }
    }

    pub fn list_compute_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .values()
            .filter(|w| w.r#type() == WorkerType::ComputeNode)
            .cloned()
            .collect()
    }

    pub fn list_frontend_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .values()
            .filter(|w| w.r#type() == WorkerType::Frontend)
            .cloned()
            .collect()
    }

    fn list_serving_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_compute_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
            .collect()
    }

    fn list_streaming_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_compute_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_streaming))
            .collect()
    }

    pub fn add_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.insert(node.id, node);
    }

    pub fn remove_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.remove(&node.id);
    }

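    /// Replace the worker node list and both fragment vnode mappings in one
    /// write-locked update, so readers never observe a partially refreshed view.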
    pub fn refresh(
        &self,
        nodes: Vec<WorkerNode>,
        streaming_mapping: HashMap<FragmentId, WorkerSlotMapping>,
        serving_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut write_guard = self.inner.write().unwrap();
        tracing::debug!("Refresh worker nodes {:?}.", nodes);
        tracing::debug!(
            "Refresh streaming vnode mapping for fragments {:?}.",
            streaming_mapping.keys()
        );
        tracing::debug!(
            "Refresh serving vnode mapping for fragments {:?}.",
            serving_mapping.keys()
        );
        write_guard.worker_nodes = nodes.into_iter().map(|w| (w.id, w)).collect();
        write_guard.streaming_fragment_vnode_mapping = Some(streaming_mapping);
        write_guard.serving_fragment_vnode_mapping = serving_mapping;
    }

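    /// Resolve each worker slot id to its owning worker node. Returns an error
    /// if the input is empty or if any slot refers to an unknown worker.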
    pub fn get_workers_by_worker_slot_ids(
        &self,
        worker_slot_ids: &[WorkerSlotId],
    ) -> Result<Vec<WorkerNode>> {
        if worker_slot_ids.is_empty() {
            return Err(BatchError::EmptyWorkerNodes);
        }
        let guard = self.inner.read().unwrap();
        let mut workers = Vec::with_capacity(worker_slot_ids.len());
        for worker_slot_id in worker_slot_ids {
            match guard.worker_nodes.get(&worker_slot_id.worker_id()) {
                Some(worker) => workers.push(worker.clone()),
                None => bail!(
                    "No worker node found for worker slot id: {}",
                    worker_slot_id
                ),
            }
        }

        Ok(workers)
    }

    pub fn get_streaming_fragment_mapping(
        &self,
        fragment_id: &FragmentId,
    ) -> Result<WorkerSlotMapping> {
        let guard = self.inner.read().unwrap();

        let Some(streaming_mapping) = guard.streaming_fragment_vnode_mapping.as_ref() else {
            return Err(BatchError::StreamingVnodeMappingNotInitialized);
        };

        streaming_mapping
            .get(fragment_id)
            .cloned()
            .ok_or_else(|| BatchError::StreamingVnodeMappingNotFound(*fragment_id))
    }

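    /// Register the streaming vnode mapping for a newly created fragment. If a
    /// mapping is already present, it is kept and the conflict is only logged.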
    pub fn insert_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        let mut guard = self.inner.write().unwrap();
        let mapping = guard
            .streaming_fragment_vnode_mapping
            .get_or_insert_with(HashMap::new);
        if mapping.try_insert(fragment_id, vnode_mapping).is_err() {
            tracing::info!(
                "Streaming vnode mapping already exists for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

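    /// Overwrite the streaming vnode mapping for a fragment. A missing previous
    /// mapping is tolerated and logged.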
    pub fn update_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        let mut guard = self.inner.write().unwrap();
        let mapping = guard
            .streaming_fragment_vnode_mapping
            .get_or_insert_with(HashMap::new);
        if mapping.insert(fragment_id, vnode_mapping).is_none() {
            tracing::info!(
                "Previous vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

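    /// Drop the streaming vnode mapping of a fragment, if any.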
    pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) {
        let mut guard = self.inner.write().unwrap();

        let res = guard
            .streaming_fragment_vnode_mapping
            .as_mut()
            .and_then(|mapping| mapping.remove(fragment_id));
        match &res {
            Some(_) => {}
            None if fragment_id.is_placeholder() => {
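                // Placeholder fragment ids are not expected to carry a mapping,
                // so a missing entry is not worth a warning here.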
227 }
229 None => {
230 tracing::warn!(%fragment_id, "Streaming vnode mapping not found");
231 }
232 };
233 }
234
    fn serving_fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .get_serving_fragment_mapping(fragment_id)
            .ok_or_else(|| BatchError::ServingVnodeMappingNotFound(fragment_id))
    }

    pub fn set_serving_fragment_mapping(&self, mappings: HashMap<FragmentId, WorkerSlotMapping>) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Set serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        guard.serving_fragment_vnode_mapping = mappings;
    }

    pub fn upsert_serving_fragment_mapping(
        &self,
        mappings: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Upsert serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        for (fragment_id, mapping) in mappings {
            guard
                .serving_fragment_vnode_mapping
                .insert(fragment_id, mapping);
        }
    }

    pub fn remove_serving_fragment_mapping(&self, fragment_ids: &[FragmentId]) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Delete serving vnode mapping for fragments {:?}",
            fragment_ids
        );
        for fragment_id in fragment_ids {
            guard.serving_fragment_vnode_mapping.remove(fragment_id);
        }
    }

    fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<WorkerId>> {
        self.worker_node_mask.read().unwrap()
    }

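    /// Temporarily hide a worker node from the serving candidate set. A spawned
    /// task removes the mask again once `duration` has elapsed.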
    pub fn mask_worker_node(&self, worker_node_id: WorkerId, duration: Duration) {
        tracing::info!(
            "Mask worker node {} for {:?} temporarily",
            worker_node_id,
            duration
        );
        let mut worker_node_mask = self.worker_node_mask.write().unwrap();
        if worker_node_mask.contains(&worker_node_id) {
            return;
        }
        worker_node_mask.insert(worker_node_id);
        let worker_node_mask_ref = self.worker_node_mask.clone();
        tokio::spawn(async move {
            tokio::time::sleep(duration).await;
            worker_node_mask_ref
                .write()
                .unwrap()
                .remove(&worker_node_id);
        });
    }

    pub fn worker_node(&self, worker_id: WorkerId) -> Option<WorkerNode> {
        self.inner.read().unwrap().worker_node(worker_id)
    }
}

impl WorkerNodeManagerInner {
    fn get_serving_fragment_mapping(&self, fragment_id: FragmentId) -> Option<WorkerSlotMapping> {
        self.serving_fragment_vnode_mapping
            .get(&fragment_id)
            .cloned()
    }

    fn worker_node(&self, worker_id: WorkerId) -> Option<WorkerNode> {
        self.worker_nodes.get(&worker_id).cloned()
    }
}

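/// A thin view over [`WorkerNodeManager`] that picks worker nodes for a single
/// query: streaming workers when barrier read is enabled, otherwise serving
/// workers with the temporary worker mask applied.
///
/// A minimal usage sketch (the manager is normally shared behind an `Arc`):
///
/// ```ignore
/// let manager: WorkerNodeManagerRef = Arc::new(WorkerNodeManager::new());
/// let selector = WorkerNodeSelector::new(manager.clone(), /* enable_barrier_read */ false);
/// let worker = selector.next_random_worker()?;
/// ```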
#[derive(Clone)]
pub struct WorkerNodeSelector {
    pub manager: WorkerNodeManagerRef,
    enable_barrier_read: bool,
}

impl WorkerNodeSelector {
    pub fn new(manager: WorkerNodeManagerRef, enable_barrier_read: bool) -> Self {
        Self {
            manager,
            enable_barrier_read,
        }
    }

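    /// Number of candidate worker nodes under the current read mode; masked
    /// workers are excluded for serving reads.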
    pub fn worker_node_count(&self) -> usize {
        if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes().len()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
                .len()
        }
    }

    pub fn schedule_unit_count(&self) -> usize {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .iter()
            .map(|node| node.compute_node_parallelism())
            .sum()
    }

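    /// Look up the worker slot mapping for scanning a fragment. Under barrier
    /// read the streaming mapping is used directly; otherwise the serving
    /// mapping is preferred, falling back to the streaming one if it is missing.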
    pub fn fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        if self.enable_barrier_read {
            self.manager.get_streaming_fragment_mapping(&fragment_id)
        } else {
            let mapping = (self.manager.serving_fragment_mapping(fragment_id)).or_else(|_| {
                tracing::warn!(
                    %fragment_id,
                    "Serving fragment mapping not found, fall back to streaming one."
                );
                self.manager.get_streaming_fragment_mapping(&fragment_id)
            })?;

            if self.manager.worker_node_mask().is_empty() {
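                // Fast path: no worker is masked, so the mapping can be used as-is.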
                Ok(mapping)
            } else {
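                // Some serving workers are masked: re-place the vnodes onto the
                // remaining workers, keeping a singleton mapping singleton
                // (`max_parallelism == Some(1)`).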
                let workers =
                    self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
                let max_parallelism = mapping.to_single().map(|_| 1);
                let masked_mapping =
                    place_vnode(Some(&mapping), &workers, max_parallelism, mapping.len())
                        .ok_or_else(|| BatchError::EmptyWorkerNodes)?;
                Ok(masked_mapping)
            }
        }
    }

    pub fn next_random_worker(&self) -> Result<WorkerNode> {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .choose(&mut rand::rng())
            .ok_or_else(|| BatchError::EmptyWorkerNodes)
            .map(|w| (*w).clone())
    }

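    /// Filter out masked workers, but if *every* candidate is masked, keep the
    /// original list so the query does not end up with an empty worker set.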
    fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
        let mask = self.manager.worker_node_mask();
        if origin.iter().all(|w| mask.contains(&w.id)) {
            return origin;
        }
        origin
            .into_iter()
            .filter(|w| !mask.contains(&w.id))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::util::addr::HostAddr;
    use risingwave_pb::common::worker_node;
    use risingwave_pb::common::worker_node::Property;

    #[test]
    fn test_worker_node_manager() {
        use super::*;

        let manager = WorkerNodeManager::mock(vec![]);
        assert_eq!(manager.list_serving_worker_nodes().len(), 0);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_compute_nodes(), vec![]);

        let worker_nodes = vec![
            WorkerNode {
                id: 1.into(),
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1234").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: true,
                    ..Default::default()
                }),
                transactional_id: Some(1),
                ..Default::default()
            },
            WorkerNode {
                id: 2.into(),
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1235").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: false,
                    ..Default::default()
                }),
                transactional_id: Some(2),
                ..Default::default()
            },
        ];
        worker_nodes
            .iter()
            .for_each(|w| manager.add_worker_node(w.clone()));
        assert_eq!(manager.list_serving_worker_nodes().len(), 2);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 1);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes
        );

        manager.remove_worker_node(worker_nodes[0].clone());
        assert_eq!(manager.list_serving_worker_nodes().len(), 1);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes.as_slice()[1..].to_vec()
        );
    }
}