risingwave_batch/worker_manager/worker_node_manager.rs

use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::time::Duration;

use rand::seq::IndexedRandom;
use risingwave_common::bail;
use risingwave_common::catalog::OBJECT_ID_PLACEHOLDER;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
use risingwave_pb::common::{WorkerNode, WorkerType};

use crate::error::{BatchError, Result};

pub(crate) type FragmentId = u32;

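/// Tracks the cluster's worker nodes plus the streaming and serving vnode mappings of
/// fragments that batch scheduling relies on. State sits behind `RwLock`s so lookups can run
/// concurrently with refreshes.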
pub struct WorkerNodeManager {
    inner: RwLock<WorkerNodeManagerInner>,
    /// Worker node ids that are temporarily masked out; see `mask_worker_node`.
    worker_node_mask: Arc<RwLock<HashSet<u32>>>,
}

struct WorkerNodeManagerInner {
    worker_nodes: HashMap<u32, WorkerNode>,
    /// Vnode mappings of streaming fragments, keyed by fragment id.
    streaming_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    /// Vnode mappings of serving fragments, keyed by fragment id.
    serving_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
}

pub type WorkerNodeManagerRef = Arc<WorkerNodeManager>;

impl Default for WorkerNodeManager {
    fn default() -> Self {
        Self::new()
    }
}

impl WorkerNodeManager {
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(WorkerNodeManagerInner {
                worker_nodes: Default::default(),
                streaming_fragment_vnode_mapping: Default::default(),
                serving_fragment_vnode_mapping: Default::default(),
            }),
            worker_node_mask: Arc::new(Default::default()),
        }
    }

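    /// Creates a manager pre-populated with the given worker nodes, for use in tests.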
    pub fn mock(worker_nodes: Vec<WorkerNode>) -> Self {
        let worker_nodes = worker_nodes.into_iter().map(|w| (w.id, w)).collect();
        let inner = RwLock::new(WorkerNodeManagerInner {
            worker_nodes,
            streaming_fragment_vnode_mapping: HashMap::new(),
            serving_fragment_vnode_mapping: HashMap::new(),
        });
        Self {
            inner,
            worker_node_mask: Arc::new(Default::default()),
        }
    }

    pub fn list_compute_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .values()
            .filter(|w| w.r#type() == WorkerType::ComputeNode)
            .cloned()
            .collect()
    }

    pub fn list_frontend_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .values()
            .filter(|w| w.r#type() == WorkerType::Frontend)
            .cloned()
            .collect()
    }

    fn list_serving_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_compute_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
            .collect()
    }

    fn list_streaming_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_compute_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_streaming))
            .collect()
    }

    pub fn add_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.insert(node.id, node);
    }

    pub fn remove_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.remove(&node.id);
    }

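    /// Replaces the worker node list and both vnode mapping tables with a fresh snapshot in
    /// one critical section.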
    pub fn refresh(
        &self,
        nodes: Vec<WorkerNode>,
        streaming_mapping: HashMap<FragmentId, WorkerSlotMapping>,
        serving_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut write_guard = self.inner.write().unwrap();
        tracing::debug!("Refresh worker nodes {:?}.", nodes);
        tracing::debug!(
            "Refresh streaming vnode mapping for fragments {:?}.",
            streaming_mapping.keys()
        );
        tracing::debug!(
            "Refresh serving vnode mapping for fragments {:?}.",
            serving_mapping.keys()
        );
        write_guard.worker_nodes = nodes.into_iter().map(|w| (w.id, w)).collect();
        write_guard.streaming_fragment_vnode_mapping = streaming_mapping;
        write_guard.serving_fragment_vnode_mapping = serving_mapping;
    }

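    /// Resolves each worker slot id to the worker node that owns it. Errors if the slot list
    /// is empty or if any referenced worker is unknown.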
    pub fn get_workers_by_worker_slot_ids(
        &self,
        worker_slot_ids: &[WorkerSlotId],
    ) -> Result<Vec<WorkerNode>> {
        if worker_slot_ids.is_empty() {
            return Err(BatchError::EmptyWorkerNodes);
        }
        let guard = self.inner.read().unwrap();
        let mut workers = Vec::with_capacity(worker_slot_ids.len());
        for worker_slot_id in worker_slot_ids {
            match guard.worker_nodes.get(&worker_slot_id.worker_id()) {
                Some(worker) => workers.push((*worker).clone()),
                None => bail!(
                    "No worker node found for worker slot id: {}",
                    worker_slot_id
                ),
            }
        }

        Ok(workers)
    }

    pub fn get_streaming_fragment_mapping(
        &self,
        fragment_id: &FragmentId,
    ) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .streaming_fragment_vnode_mapping
            .get(fragment_id)
            .cloned()
            .ok_or_else(|| BatchError::StreamingVnodeMappingNotFound(*fragment_id))
    }

    pub fn insert_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        if self
            .inner
            .write()
            .unwrap()
            .streaming_fragment_vnode_mapping
            .try_insert(fragment_id, vnode_mapping)
            .is_err()
        {
            // `try_insert` only fails when a mapping is already present.
            tracing::info!(
                "Previous batch vnode mapping already exists for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

    pub fn update_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        let mut guard = self.inner.write().unwrap();
        if guard
            .streaming_fragment_vnode_mapping
            .insert(fragment_id, vnode_mapping)
            .is_none()
        {
            tracing::info!(
                "Previous vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

    pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) {
        let mut guard = self.inner.write().unwrap();

        let res = guard.streaming_fragment_vnode_mapping.remove(fragment_id);
        match &res {
            Some(_) => {}
            None if OBJECT_ID_PLACEHOLDER == *fragment_id => {
                // A missing mapping is expected for the placeholder id, so stay silent.
            }
            None => {
                tracing::warn!(fragment_id, "Streaming vnode mapping not found");
            }
        };
    }

    fn serving_fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .get_serving_fragment_mapping(fragment_id)
            .ok_or_else(|| BatchError::ServingVnodeMappingNotFound(fragment_id))
    }

    pub fn set_serving_fragment_mapping(&self, mappings: HashMap<FragmentId, WorkerSlotMapping>) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Set serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        guard.serving_fragment_vnode_mapping = mappings;
    }

    pub fn upsert_serving_fragment_mapping(
        &self,
        mappings: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Upsert serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        for (fragment_id, mapping) in mappings {
            guard
                .serving_fragment_vnode_mapping
                .insert(fragment_id, mapping);
        }
    }

    pub fn remove_serving_fragment_mapping(&self, fragment_ids: &[FragmentId]) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Delete serving vnode mapping for fragments {:?}",
            fragment_ids
        );
        for fragment_id in fragment_ids {
            guard.serving_fragment_vnode_mapping.remove(fragment_id);
        }
    }

    fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<u32>> {
        self.worker_node_mask.read().unwrap()
    }

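    /// Temporarily hides a worker node from batch scheduling; a background task lifts the
    /// mask once `duration` has elapsed.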
    pub fn mask_worker_node(&self, worker_node_id: u32, duration: Duration) {
        tracing::info!(
            "Mask worker node {} for {:?} temporarily",
            worker_node_id,
            duration
        );
        let mut worker_node_mask = self.worker_node_mask.write().unwrap();
        if worker_node_mask.contains(&worker_node_id) {
            return;
        }
        worker_node_mask.insert(worker_node_id);
        let worker_node_mask_ref = self.worker_node_mask.clone();
        // Unmask the worker automatically after the given duration.
        tokio::spawn(async move {
            tokio::time::sleep(duration).await;
            worker_node_mask_ref
                .write()
                .unwrap()
                .remove(&worker_node_id);
        });
    }

    pub fn worker_node(&self, worker_id: u32) -> Option<WorkerNode> {
        self.inner.read().unwrap().worker_node(worker_id)
    }
}

impl WorkerNodeManagerInner {
    fn get_serving_fragment_mapping(&self, fragment_id: FragmentId) -> Option<WorkerSlotMapping> {
        self.serving_fragment_vnode_mapping
            .get(&fragment_id)
            .cloned()
    }

    fn worker_node(&self, worker_id: u32) -> Option<WorkerNode> {
        self.worker_nodes.get(&worker_id).cloned()
    }
}

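/// Selects worker nodes for batch query scheduling. When `enable_barrier_read` is set,
/// queries are scheduled on the streaming compute nodes; otherwise on the serving nodes,
/// with any temporarily masked workers filtered out.
///
/// A minimal usage sketch (assumes an already-populated `manager`):
///
/// ```ignore
/// // `false` disables barrier read, so serving nodes are used.
/// let selector = WorkerNodeSelector::new(manager.clone(), false);
/// let worker = selector.next_random_worker()?;
/// ```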
#[derive(Clone)]
pub struct WorkerNodeSelector {
    pub manager: WorkerNodeManagerRef,
    enable_barrier_read: bool,
}

impl WorkerNodeSelector {
    pub fn new(manager: WorkerNodeManagerRef, enable_barrier_read: bool) -> Self {
        Self {
            manager,
            enable_barrier_read,
        }
    }

    pub fn worker_node_count(&self) -> usize {
        if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes().len()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
                .len()
        }
    }

    pub fn schedule_unit_count(&self) -> usize {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .iter()
            .map(|node| node.compute_node_parallelism())
            .sum()
    }

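    /// Returns the worker slot mapping used to schedule reads of the given fragment. Barrier
    /// reads use the streaming mapping directly; otherwise the serving mapping is preferred,
    /// falling back to the streaming one, and the vnodes are re-placed onto unmasked workers
    /// whenever a mask is active.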
    pub fn fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        if self.enable_barrier_read {
            self.manager.get_streaming_fragment_mapping(&fragment_id)
        } else {
            let mapping = (self.manager.serving_fragment_mapping(fragment_id)).or_else(|_| {
                tracing::warn!(
                    fragment_id,
                    "Serving fragment mapping not found, fall back to streaming one."
                );
                self.manager.get_streaming_fragment_mapping(&fragment_id)
            })?;

            // If no worker is masked, the mapping can be used as-is.
            if self.manager.worker_node_mask().is_empty() {
                Ok(mapping)
            } else {
                let workers =
                    self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
                // Keep a singleton mapping singleton by capping the parallelism at 1.
                let max_parallelism = mapping.to_single().map(|_| 1);
                // Re-place the vnodes onto the remaining (unmasked) workers.
                let masked_mapping =
                    place_vnode(Some(&mapping), &workers, max_parallelism, mapping.len())
                        .ok_or_else(|| BatchError::EmptyWorkerNodes)?;
                Ok(masked_mapping)
            }
        }
    }

    pub fn next_random_worker(&self) -> Result<WorkerNode> {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .choose(&mut rand::rng())
            .ok_or_else(|| BatchError::EmptyWorkerNodes)
            .map(|w| (*w).clone())
    }

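    /// Drops masked workers from `origin`. If every worker is masked, the mask is ignored and
    /// the original list is returned so that scheduling still has candidates.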
    fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
        let mask = self.manager.worker_node_mask();
        if origin.iter().all(|w| mask.contains(&w.id)) {
            return origin;
        }
        origin
            .into_iter()
            .filter(|w| !mask.contains(&w.id))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::util::addr::HostAddr;
    use risingwave_pb::common::worker_node;
    use risingwave_pb::common::worker_node::Property;

    #[test]
    fn test_worker_node_manager() {
        use super::*;

        let manager = WorkerNodeManager::mock(vec![]);
        assert_eq!(manager.list_serving_worker_nodes().len(), 0);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_compute_nodes(), vec![]);

        let worker_nodes = vec![
            WorkerNode {
                id: 1,
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1234").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: true,
                    ..Default::default()
                }),
                transactional_id: Some(1),
                ..Default::default()
            },
            WorkerNode {
                id: 2,
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1235").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: false,
                    ..Default::default()
                }),
                transactional_id: Some(2),
                ..Default::default()
            },
        ];
        worker_nodes
            .iter()
            .for_each(|w| manager.add_worker_node(w.clone()));
        assert_eq!(manager.list_serving_worker_nodes().len(), 2);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 1);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes
        );

        manager.remove_worker_node(worker_nodes[0].clone());
        assert_eq!(manager.list_serving_worker_nodes().len(), 1);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes.as_slice()[1..].to_vec()
        );
    }
}