// risingwave_batch/worker_manager/worker_node_manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::sync::{Arc, RwLock, RwLockReadGuard};
17use std::time::Duration;
18
19use rand::seq::IndexedRandom;
20use risingwave_common::bail;
21use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
22use risingwave_common::id::{FragmentId, WorkerId};
23use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
24use risingwave_pb::common::{WorkerNode, WorkerType};
25
26use crate::error::{BatchError, Result};
27
28/// `WorkerNodeManager` manages live worker nodes and table vnode mapping information.
pub struct WorkerNodeManager {
    /// All mutable state (worker list + vnode mappings), guarded by one lock.
    inner: RwLock<WorkerNodeManagerInner>,
    /// Temporarily make worker invisible from serving cluster.
    /// Wrapped in an `Arc` so the timer task spawned by `mask_worker_node`
    /// can remove the mask entry after the duration elapses.
    worker_node_mask: Arc<RwLock<HashSet<WorkerId>>>,
}
34
/// State behind [`WorkerNodeManager`]'s lock.
struct WorkerNodeManagerInner {
    /// All known worker nodes (compute and frontend alike), keyed by worker id.
    worker_nodes: HashMap<WorkerId, WorkerNode>,
    /// fragment vnode mapping info for streaming.
    /// `None` until initialized by the first `refresh` (or lazily by insert/update).
    streaming_fragment_vnode_mapping: Option<HashMap<FragmentId, WorkerSlotMapping>>,
    /// fragment vnode mapping info for serving
    serving_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
}
42
/// Shared, reference-counted handle to a [`WorkerNodeManager`].
pub type WorkerNodeManagerRef = Arc<WorkerNodeManager>;
44
impl Default for WorkerNodeManager {
    /// Same as [`WorkerNodeManager::new`]: empty worker set and mappings.
    fn default() -> Self {
        Self::new()
    }
}
50
51impl WorkerNodeManager {
52    pub fn new() -> Self {
53        Self {
54            inner: RwLock::new(WorkerNodeManagerInner {
55                worker_nodes: Default::default(),
56                streaming_fragment_vnode_mapping: None,
57                serving_fragment_vnode_mapping: Default::default(),
58            }),
59            worker_node_mask: Arc::new(Default::default()),
60        }
61    }
62
63    /// Used in tests.
64    pub fn mock(worker_nodes: Vec<WorkerNode>) -> Self {
65        let worker_nodes = worker_nodes.into_iter().map(|w| (w.id, w)).collect();
66        let inner = RwLock::new(WorkerNodeManagerInner {
67            worker_nodes,
68            streaming_fragment_vnode_mapping: None,
69            serving_fragment_vnode_mapping: HashMap::new(),
70        });
71        Self {
72            inner,
73            worker_node_mask: Arc::new(Default::default()),
74        }
75    }
76
77    pub fn list_compute_nodes(&self) -> Vec<WorkerNode> {
78        self.inner
79            .read()
80            .unwrap()
81            .worker_nodes
82            .values()
83            .filter(|w| w.r#type() == WorkerType::ComputeNode)
84            .cloned()
85            .collect()
86    }
87
88    pub fn list_frontend_nodes(&self) -> Vec<WorkerNode> {
89        self.inner
90            .read()
91            .unwrap()
92            .worker_nodes
93            .values()
94            .filter(|w| w.r#type() == WorkerType::Frontend)
95            .cloned()
96            .collect()
97    }
98
99    fn list_serving_worker_nodes(&self) -> Vec<WorkerNode> {
100        self.list_compute_nodes()
101            .into_iter()
102            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
103            .collect()
104    }
105
106    fn list_streaming_worker_nodes(&self) -> Vec<WorkerNode> {
107        self.list_compute_nodes()
108            .into_iter()
109            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_streaming))
110            .collect()
111    }
112
113    pub fn add_worker_node(&self, node: WorkerNode) {
114        let mut write_guard = self.inner.write().unwrap();
115        write_guard.worker_nodes.insert(node.id, node);
116    }
117
118    pub fn remove_worker_node(&self, node: WorkerNode) {
119        let mut write_guard = self.inner.write().unwrap();
120        write_guard.worker_nodes.remove(&node.id);
121    }
122
123    pub fn refresh(
124        &self,
125        nodes: Vec<WorkerNode>,
126        streaming_mapping: HashMap<FragmentId, WorkerSlotMapping>,
127        serving_mapping: HashMap<FragmentId, WorkerSlotMapping>,
128    ) {
129        let mut write_guard = self.inner.write().unwrap();
130        tracing::debug!("Refresh worker nodes {:?}.", nodes);
131        tracing::debug!(
132            "Refresh streaming vnode mapping for fragments {:?}.",
133            streaming_mapping.keys()
134        );
135        tracing::debug!(
136            "Refresh serving vnode mapping for fragments {:?}.",
137            serving_mapping.keys()
138        );
139        write_guard.worker_nodes = nodes.into_iter().map(|w| (w.id, w)).collect();
140        write_guard.streaming_fragment_vnode_mapping = Some(streaming_mapping);
141        write_guard.serving_fragment_vnode_mapping = serving_mapping;
142    }
143
144    /// If worker slot ids is empty, the scheduler may fail to schedule any task and stuck at
145    /// schedule next stage. If we do not return error in this case, needs more complex control
146    /// logic above. Report in this function makes the schedule root fail reason more clear.
147    pub fn get_workers_by_worker_slot_ids(
148        &self,
149        worker_slot_ids: &[WorkerSlotId],
150    ) -> Result<Vec<WorkerNode>> {
151        if worker_slot_ids.is_empty() {
152            return Err(BatchError::EmptyWorkerNodes);
153        }
154        let guard = self.inner.read().unwrap();
155        let mut workers = Vec::with_capacity(worker_slot_ids.len());
156        for worker_slot_id in worker_slot_ids {
157            match guard.worker_nodes.get(&worker_slot_id.worker_id()) {
158                Some(worker) => workers.push((*worker).clone()),
159                None => bail!(
160                    "No worker node found for worker slot id: {}",
161                    worker_slot_id
162                ),
163            }
164        }
165
166        Ok(workers)
167    }
168
169    pub fn get_streaming_fragment_mapping(
170        &self,
171        fragment_id: &FragmentId,
172    ) -> Result<WorkerSlotMapping> {
173        let guard = self.inner.read().unwrap();
174
175        let Some(streaming_mapping) = guard.streaming_fragment_vnode_mapping.as_ref() else {
176            return Err(BatchError::StreamingVnodeMappingNotInitialized);
177        };
178
179        streaming_mapping
180            .get(fragment_id)
181            .cloned()
182            .ok_or_else(|| BatchError::StreamingVnodeMappingNotFound(*fragment_id))
183    }
184
185    pub fn insert_streaming_fragment_mapping(
186        &self,
187        fragment_id: FragmentId,
188        vnode_mapping: WorkerSlotMapping,
189    ) {
190        let mut guard = self.inner.write().unwrap();
191        let mapping = guard
192            .streaming_fragment_vnode_mapping
193            .get_or_insert_with(HashMap::new);
194        if mapping.try_insert(fragment_id, vnode_mapping).is_err() {
195            tracing::info!(
196                "Previous batch vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
197            );
198        }
199    }
200
201    pub fn update_streaming_fragment_mapping(
202        &self,
203        fragment_id: FragmentId,
204        vnode_mapping: WorkerSlotMapping,
205    ) {
206        let mut guard = self.inner.write().unwrap();
207        let mapping = guard
208            .streaming_fragment_vnode_mapping
209            .get_or_insert_with(HashMap::new);
210        if mapping.insert(fragment_id, vnode_mapping).is_none() {
211            tracing::info!(
212                "Previous vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
213            );
214        }
215    }
216
217    pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) {
218        let mut guard = self.inner.write().unwrap();
219
220        let res = guard
221            .streaming_fragment_vnode_mapping
222            .as_mut()
223            .and_then(|mapping| mapping.remove(fragment_id));
224        match &res {
225            Some(_) => {}
226            None if fragment_id.is_placeholder() => {
227                // Do nothing for placeholder fragment.
228            }
229            None => {
230                tracing::warn!(%fragment_id, "Streaming vnode mapping not found");
231            }
232        };
233    }
234
235    /// Returns fragment's vnode mapping for serving.
236    fn serving_fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
237        self.inner
238            .read()
239            .unwrap()
240            .get_serving_fragment_mapping(fragment_id)
241            .ok_or_else(|| BatchError::ServingVnodeMappingNotFound(fragment_id))
242    }
243
244    pub fn set_serving_fragment_mapping(&self, mappings: HashMap<FragmentId, WorkerSlotMapping>) {
245        let mut guard = self.inner.write().unwrap();
246        tracing::debug!(
247            "Set serving vnode mapping for fragments {:?}",
248            mappings.keys()
249        );
250        guard.serving_fragment_vnode_mapping = mappings;
251    }
252
253    pub fn upsert_serving_fragment_mapping(
254        &self,
255        mappings: HashMap<FragmentId, WorkerSlotMapping>,
256    ) {
257        let mut guard = self.inner.write().unwrap();
258        tracing::debug!(
259            "Upsert serving vnode mapping for fragments {:?}",
260            mappings.keys()
261        );
262        for (fragment_id, mapping) in mappings {
263            guard
264                .serving_fragment_vnode_mapping
265                .insert(fragment_id, mapping);
266        }
267    }
268
269    pub fn remove_serving_fragment_mapping(&self, fragment_ids: &[FragmentId]) {
270        let mut guard = self.inner.write().unwrap();
271        tracing::debug!(
272            "Delete serving vnode mapping for fragments {:?}",
273            fragment_ids
274        );
275        for fragment_id in fragment_ids {
276            guard.serving_fragment_vnode_mapping.remove(fragment_id);
277        }
278    }
279
280    fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<WorkerId>> {
281        self.worker_node_mask.read().unwrap()
282    }
283
284    pub fn mask_worker_node(&self, worker_node_id: WorkerId, duration: Duration) {
285        tracing::info!(
286            "Mask worker node {} for {:?} temporarily",
287            worker_node_id,
288            duration
289        );
290        let mut worker_node_mask = self.worker_node_mask.write().unwrap();
291        if worker_node_mask.contains(&worker_node_id) {
292            return;
293        }
294        worker_node_mask.insert(worker_node_id);
295        let worker_node_mask_ref = self.worker_node_mask.clone();
296        tokio::spawn(async move {
297            tokio::time::sleep(duration).await;
298            worker_node_mask_ref
299                .write()
300                .unwrap()
301                .remove(&worker_node_id);
302        });
303    }
304
305    pub fn worker_node(&self, worker_id: WorkerId) -> Option<WorkerNode> {
306        self.inner.read().unwrap().worker_node(worker_id)
307    }
308}
309
310impl WorkerNodeManagerInner {
311    fn get_serving_fragment_mapping(&self, fragment_id: FragmentId) -> Option<WorkerSlotMapping> {
312        self.serving_fragment_vnode_mapping
313            .get(&fragment_id)
314            .cloned()
315    }
316
317    fn worker_node(&self, worker_id: WorkerId) -> Option<WorkerNode> {
318        self.worker_nodes.get(&worker_id).cloned()
319    }
320}
321
/// Selects workers for query according to `enable_barrier_read`
#[derive(Clone)]
pub struct WorkerNodeSelector {
    /// Underlying manager supplying worker lists and vnode mappings.
    pub manager: WorkerNodeManagerRef,
    /// If true, use streaming workers (barrier read); otherwise use serving
    /// workers with the temporary worker-node mask applied.
    enable_barrier_read: bool,
}
328
impl WorkerNodeSelector {
    /// Creates a selector over `manager`; `enable_barrier_read` decides whether
    /// queries target streaming workers or (masked) serving workers.
    pub fn new(manager: WorkerNodeManagerRef, enable_barrier_read: bool) -> Self {
        Self {
            manager,
            enable_barrier_read,
        }
    }

    /// Number of candidate workers for the current read mode.
    /// NOTE: the mask is applied only in the serving (non-barrier-read) path.
    pub fn worker_node_count(&self) -> usize {
        if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes().len()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
                .len()
        }
    }

    /// Total parallelism (sum of compute-node parallelism) across candidate workers.
    pub fn schedule_unit_count(&self) -> usize {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .iter()
            .map(|node| node.compute_node_parallelism())
            .sum()
    }

    /// Resolves the vnode mapping used to schedule scans of `fragment_id`.
    ///
    /// Barrier read: the streaming mapping, verbatim.
    /// Serving read: the serving mapping (falling back to the streaming one with
    /// a warning); if any workers are currently masked, the mapping is re-placed
    /// via `place_vnode` onto the unmasked serving workers.
    pub fn fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
        if self.enable_barrier_read {
            self.manager.get_streaming_fragment_mapping(&fragment_id)
        } else {
            let mapping = (self.manager.serving_fragment_mapping(fragment_id)).or_else(|_| {
                tracing::warn!(
                    %fragment_id,
                    "Serving fragment mapping not found, fall back to streaming one."
                );
                self.manager.get_streaming_fragment_mapping(&fragment_id)
            })?;

            // Filter out unavailable workers.
            if self.manager.worker_node_mask().is_empty() {
                Ok(mapping)
            } else {
                let workers = self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
                // If it's a singleton, set max_parallelism=1 for place_vnode.
                let max_parallelism = mapping.to_single().map(|_| 1);
                // TODO: use runtime parameter batch_parallelism
                let masked_mapping =
                    place_vnode(Some(&mapping), &workers, max_parallelism, mapping.len())
                        .ok_or_else(|| BatchError::EmptyWorkerNodes)?;
                Ok(masked_mapping)
            }
        }
    }

    /// Picks a uniformly random candidate worker for the current read mode.
    ///
    /// # Errors
    /// `EmptyWorkerNodes` if no candidate worker is available.
    pub fn next_random_worker(&self) -> Result<WorkerNode> {
        let worker_nodes = if self.enable_barrier_read {
            self.manager.list_streaming_worker_nodes()
        } else {
            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
        };
        worker_nodes
            .choose(&mut rand::rng())
            .ok_or_else(|| BatchError::EmptyWorkerNodes)
            .map(|w| (*w).clone())
    }

    /// Drops masked workers from `origin` — unless *every* worker is masked,
    /// in which case the full list is returned unchanged so callers are never
    /// left with an empty candidate set due to masking alone.
    fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
        let mask = self.manager.worker_node_mask();
        if origin.iter().all(|w| mask.contains(&w.id)) {
            return origin;
        }
        origin
            .into_iter()
            .filter(|w| !mask.contains(&w.id))
            .collect()
    }
}
409
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::util::addr::HostAddr;
    use risingwave_pb::common::worker_node;
    use risingwave_pb::common::worker_node::Property;

    /// Exercises add/remove of workers and the serving/streaming list filters.
    #[test]
    fn test_worker_node_manager() {
        use super::*;

        // Empty manager: every listing is empty.
        let manager = WorkerNodeManager::mock(vec![]);
        assert_eq!(manager.list_serving_worker_nodes().len(), 0);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_compute_nodes(), vec![]);

        // Worker 1 is serving + streaming; worker 2 is serving only.
        let worker_nodes = vec![
            WorkerNode {
                id: 1.into(),
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1234").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: true,
                    ..Default::default()
                }),
                transactional_id: Some(1),
                ..Default::default()
            },
            WorkerNode {
                id: 2.into(),
                r#type: WorkerType::ComputeNode as i32,
                host: Some(HostAddr::try_from("127.0.0.1:1235").unwrap().to_protobuf()),
                state: worker_node::State::Running as i32,
                property: Some(Property {
                    is_unschedulable: false,
                    is_serving: true,
                    is_streaming: false,
                    ..Default::default()
                }),
                transactional_id: Some(2),
                ..Default::default()
            },
        ];
        worker_nodes
            .iter()
            .for_each(|w| manager.add_worker_node(w.clone()));
        // Both serve; only worker 1 streams.
        assert_eq!(manager.list_serving_worker_nodes().len(), 2);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 1);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes
        );

        // Removing worker 1 leaves only the serving-only worker 2.
        manager.remove_worker_node(worker_nodes[0].clone());
        assert_eq!(manager.list_serving_worker_nodes().len(), 1);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(
            manager
                .list_compute_nodes()
                .into_iter()
                .sorted_by_key(|w| w.id)
                .collect_vec(),
            worker_nodes.as_slice()[1..].to_vec()
        );
    }
}