risingwave_batch/worker_manager/worker_node_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::time::Duration;

use rand::seq::IndexedRandom;
use risingwave_common::bail;
use risingwave_common::catalog::OBJECT_ID_PLACEHOLDER;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
use risingwave_pb::common::{WorkerNode, WorkerType};

use crate::error::{BatchError, Result};

pub(crate) type FragmentId = u32;

/// `WorkerNodeManager` manages live worker nodes and table vnode mapping information.
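///
/// A minimal usage sketch of the methods defined below (`worker_node` and `fragment_id`
/// are placeholders):
///
/// ```ignore
/// let manager = WorkerNodeManager::new();
/// manager.add_worker_node(worker_node);
/// let compute_nodes = manager.list_compute_nodes();
/// let mapping = manager.get_streaming_fragment_mapping(&fragment_id)?;
/// ```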
pub struct WorkerNodeManager {
    inner: RwLock<WorkerNodeManagerInner>,
    /// Temporarily makes workers invisible to the serving cluster.
    worker_node_mask: Arc<RwLock<HashSet<u32>>>,
}

struct WorkerNodeManagerInner {
    worker_nodes: HashMap<u32, WorkerNode>,
    /// Fragment vnode mapping info for streaming.
    streaming_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    /// Fragment vnode mapping info for serving.
    serving_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
}

pub type WorkerNodeManagerRef = Arc<WorkerNodeManager>;

impl Default for WorkerNodeManager {
    fn default() -> Self {
        Self::new()
    }
}

impl WorkerNodeManager {
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(WorkerNodeManagerInner {
                worker_nodes: Default::default(),
                streaming_fragment_vnode_mapping: Default::default(),
                serving_fragment_vnode_mapping: Default::default(),
            }),
            worker_node_mask: Arc::new(Default::default()),
        }
    }

    /// Used in tests.
    pub fn mock(worker_nodes: Vec<WorkerNode>) -> Self {
        let worker_nodes = worker_nodes.into_iter().map(|w| (w.id, w)).collect();
        let inner = RwLock::new(WorkerNodeManagerInner {
            worker_nodes,
            streaming_fragment_vnode_mapping: HashMap::new(),
            serving_fragment_vnode_mapping: HashMap::new(),
        });
        Self {
            inner,
            worker_node_mask: Arc::new(Default::default()),
        }
    }

    pub fn list_compute_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .values()
            .filter(|w| w.r#type() == WorkerType::ComputeNode)
            .cloned()
            .collect()
    }

    pub fn list_frontend_nodes(&self) -> Vec<WorkerNode> {
        self.inner
            .read()
            .unwrap()
            .worker_nodes
            .values()
            .filter(|w| w.r#type() == WorkerType::Frontend)
            .cloned()
            .collect()
    }

    fn list_serving_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_compute_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
            .collect()
    }

    fn list_streaming_worker_nodes(&self) -> Vec<WorkerNode> {
        self.list_compute_nodes()
            .into_iter()
            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_streaming))
            .collect()
    }

    pub fn add_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.insert(node.id, node);
    }

    pub fn remove_worker_node(&self, node: WorkerNode) {
        let mut write_guard = self.inner.write().unwrap();
        write_guard.worker_nodes.remove(&node.id);
    }

    pub fn refresh(
        &self,
        nodes: Vec<WorkerNode>,
        streaming_mapping: HashMap<FragmentId, WorkerSlotMapping>,
        serving_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut write_guard = self.inner.write().unwrap();
        tracing::debug!("Refresh worker nodes {:?}.", nodes);
        tracing::debug!(
            "Refresh streaming vnode mapping for fragments {:?}.",
            streaming_mapping.keys()
        );
        tracing::debug!(
            "Refresh serving vnode mapping for fragments {:?}.",
            serving_mapping.keys()
        );
        write_guard.worker_nodes = nodes.into_iter().map(|w| (w.id, w)).collect();
        write_guard.streaming_fragment_vnode_mapping = streaming_mapping;
        write_guard.serving_fragment_vnode_mapping = serving_mapping;
    }

    /// If `worker_slot_ids` is empty, the scheduler may fail to schedule any task and get stuck
    /// when scheduling the next stage. If we did not return an error in this case, the control
    /// logic above would need to be more complex. Reporting it in this function makes the root
    /// cause of a scheduling failure clearer.
    pub fn get_workers_by_worker_slot_ids(
        &self,
        worker_slot_ids: &[WorkerSlotId],
    ) -> Result<Vec<WorkerNode>> {
        if worker_slot_ids.is_empty() {
            return Err(BatchError::EmptyWorkerNodes);
        }
        let guard = self.inner.read().unwrap();
        let mut workers = Vec::with_capacity(worker_slot_ids.len());
        for worker_slot_id in worker_slot_ids {
            match guard.worker_nodes.get(&worker_slot_id.worker_id()) {
                Some(worker) => workers.push((*worker).clone()),
                None => bail!(
                    "No worker node found for worker slot id: {}",
                    worker_slot_id
                ),
            }
        }

        Ok(workers)
    }

    pub fn get_streaming_fragment_mapping(
        &self,
        fragment_id: &FragmentId,
    ) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .streaming_fragment_vnode_mapping
            .get(fragment_id)
            .cloned()
            .ok_or_else(|| BatchError::StreamingVnodeMappingNotFound(*fragment_id))
    }

    pub fn insert_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        if self
            .inner
            .write()
            .unwrap()
            .streaming_fragment_vnode_mapping
            .try_insert(fragment_id, vnode_mapping)
            .is_err()
        {
            tracing::info!(
                "Previous batch vnode mapping already exists for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

    pub fn update_streaming_fragment_mapping(
        &self,
        fragment_id: FragmentId,
        vnode_mapping: WorkerSlotMapping,
    ) {
        let mut guard = self.inner.write().unwrap();
        if guard
            .streaming_fragment_vnode_mapping
            .insert(fragment_id, vnode_mapping)
            .is_none()
        {
            tracing::info!(
                "Previous vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
            );
        }
    }

    pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) {
        let mut guard = self.inner.write().unwrap();

        let res = guard.streaming_fragment_vnode_mapping.remove(fragment_id);
        match &res {
            Some(_) => {}
            None if OBJECT_ID_PLACEHOLDER == *fragment_id => {
                // Do nothing for placeholder fragment.
            }
            None => {
                tracing::warn!(fragment_id, "Streaming vnode mapping not found");
            }
        };
    }

    /// Returns fragment's vnode mapping for serving.
    fn serving_fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        self.inner
            .read()
            .unwrap()
            .get_serving_fragment_mapping(fragment_id)
            .ok_or_else(|| BatchError::ServingVnodeMappingNotFound(fragment_id))
    }

    pub fn set_serving_fragment_mapping(&self, mappings: HashMap<FragmentId, WorkerSlotMapping>) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Set serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        guard.serving_fragment_vnode_mapping = mappings;
    }

    pub fn upsert_serving_fragment_mapping(
        &self,
        mappings: HashMap<FragmentId, WorkerSlotMapping>,
    ) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Upsert serving vnode mapping for fragments {:?}",
            mappings.keys()
        );
        for (fragment_id, mapping) in mappings {
            guard
                .serving_fragment_vnode_mapping
                .insert(fragment_id, mapping);
        }
    }

    pub fn remove_serving_fragment_mapping(&self, fragment_ids: &[FragmentId]) {
        let mut guard = self.inner.write().unwrap();
        tracing::debug!(
            "Delete serving vnode mapping for fragments {:?}",
            fragment_ids
        );
        for fragment_id in fragment_ids {
            guard.serving_fragment_vnode_mapping.remove(fragment_id);
        }
    }

    fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<u32>> {
        self.worker_node_mask.read().unwrap()
    }

    pub fn mask_worker_node(&self, worker_node_id: u32, duration: Duration) {
        tracing::info!(
            "Mask worker node {} for {:?} temporarily",
            worker_node_id,
            duration
        );
        let mut worker_node_mask = self.worker_node_mask.write().unwrap();
        if worker_node_mask.contains(&worker_node_id) {
            return;
        }
        worker_node_mask.insert(worker_node_id);
        let worker_node_mask_ref = self.worker_node_mask.clone();
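        // Spawn a background task that lifts the mask once `duration` has elapsed. Note that
        // masking an already-masked worker returns early above and does not extend the
        // existing mask duration.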
        tokio::spawn(async move {
            tokio::time::sleep(duration).await;
            worker_node_mask_ref
                .write()
                .unwrap()
                .remove(&worker_node_id);
        });
    }

    pub fn worker_node(&self, worker_id: u32) -> Option<WorkerNode> {
        self.inner.read().unwrap().worker_node(worker_id)
    }
}

impl WorkerNodeManagerInner {
    fn get_serving_fragment_mapping(&self, fragment_id: FragmentId) -> Option<WorkerSlotMapping> {
        self.serving_fragment_vnode_mapping
            .get(&fragment_id)
            .cloned()
    }

    fn worker_node(&self, worker_id: u32) -> Option<WorkerNode> {
        self.worker_nodes.get(&worker_id).cloned()
    }
}

/// Selects workers for a query according to `enable_barrier_read`.
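///
/// A minimal usage sketch, assuming an existing [`WorkerNodeManagerRef`] named `manager`
/// and a placeholder `fragment_id`:
///
/// ```ignore
/// let selector = WorkerNodeSelector::new(manager, /* enable_barrier_read */ false);
/// let mapping = selector.fragment_mapping(fragment_id)?;
/// let worker = selector.next_random_worker()?;
/// ```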
#[derive(Clone)]
pub struct WorkerNodeSelector {
    pub manager: WorkerNodeManagerRef,
    enable_barrier_read: bool,
}

impl WorkerNodeSelector {
    pub fn new(manager: WorkerNodeManagerRef, enable_barrier_read: bool) -> Self {
        Self {
            manager,
            enable_barrier_read,
        }
    }

    pub fn worker_node_count(&self) -> usize {
        if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes().len()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
                .len()
        }
    }

    pub fn schedule_unit_count(&self) -> usize {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .iter()
            .map(|node| node.compute_node_parallelism())
            .sum()
    }

    pub fn fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        if self.enable_barrier_read {
            self.manager.get_streaming_fragment_mapping(&fragment_id)
        } else {
            let mapping = (self.manager.serving_fragment_mapping(fragment_id)).or_else(|_| {
                tracing::warn!(
                    fragment_id,
                    "Serving fragment mapping not found, fall back to streaming one."
                );
                self.manager.get_streaming_fragment_mapping(&fragment_id)
            })?;

            // Filter out unavailable workers.
            if self.manager.worker_node_mask().is_empty() {
                Ok(mapping)
            } else {
                let workers = self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
                // If it's a singleton, set max_parallelism=1 for place_vnode.
                let max_parallelism = mapping.to_single().map(|_| 1);
                // TODO: use runtime parameter batch_parallelism
                let masked_mapping =
                    place_vnode(Some(&mapping), &workers, max_parallelism, mapping.len())
                        .ok_or_else(|| BatchError::EmptyWorkerNodes)?;
                Ok(masked_mapping)
            }
        }
    }

    pub fn next_random_worker(&self) -> Result<WorkerNode> {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .choose(&mut rand::rng())
            .ok_or_else(|| BatchError::EmptyWorkerNodes)
            .map(|w| (*w).clone())
    }

    fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
        let mask = self.manager.worker_node_mask();
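        // If every candidate worker is masked, ignore the mask entirely rather than
        // returning an empty set, so that scheduling can still make progress.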
        if origin.iter().all(|w| mask.contains(&w.id)) {
            return origin;
        }
        origin
            .into_iter()
            .filter(|w| !mask.contains(&w.id))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::util::addr::HostAddr;
    use risingwave_pb::common::worker_node;
    use risingwave_pb::common::worker_node::Property;

    #[test]
    fn test_worker_node_manager() {
        use super::*;

        let manager = WorkerNodeManager::mock(vec![]);
        assert_eq!(manager.list_serving_worker_nodes().len(), 0);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_compute_nodes(), vec![]);

        let worker_nodes = vec![
            WorkerNode {
                id: 1,
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1234").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: true,
                    ..Default::default()
                }),
                transactional_id: Some(1),
                ..Default::default()
            },
            WorkerNode {
                id: 2,
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1235").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: false,
                    ..Default::default()
                }),
                transactional_id: Some(2),
                ..Default::default()
            },
        ];
        worker_nodes
            .iter()
            .for_each(|w| manager.add_worker_node(w.clone()));
        assert_eq!(manager.list_serving_worker_nodes().len(), 2);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 1);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes
        );

        manager.remove_worker_node(worker_nodes[0].clone());
        assert_eq!(manager.list_serving_worker_nodes().len(), 1);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes.as_slice()[1..].to_vec()
        );
    }
}