// risingwave_batch/worker_manager/worker_node_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::time::Duration;

use rand::seq::IndexedRandom;
use risingwave_common::bail;
use risingwave_common::catalog::OBJECT_ID_PLACEHOLDER;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
use risingwave_pb::common::{WorkerNode, WorkerType};

use crate::error::{BatchError, Result};
27
/// Fragment id, used as the key of the streaming/serving vnode mapping tables.
pub(crate) type FragmentId = u32;

/// `WorkerNodeManager` manages live worker nodes and table vnode mapping information.
pub struct WorkerNodeManager {
    /// Worker list plus the streaming/serving vnode mappings, guarded by one lock.
    inner: RwLock<WorkerNodeManagerInner>,
    /// Temporarily make worker invisible from serving cluster.
    /// Entries are inserted by `mask_worker_node` and removed automatically by a
    /// spawned task after the requested duration elapses.
    worker_node_mask: Arc<RwLock<HashSet<u32>>>,
}

/// Interior state of [`WorkerNodeManager`]; always accessed through its `RwLock`.
struct WorkerNodeManagerInner {
    /// All known worker nodes; may include non-compute workers — `list_worker_nodes`
    /// filters by `WorkerType::ComputeNode`.
    worker_nodes: Vec<WorkerNode>,
    /// fragment vnode mapping info for streaming
    streaming_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
    /// fragment vnode mapping info for serving
    serving_fragment_vnode_mapping: HashMap<FragmentId, WorkerSlotMapping>,
}

/// Shared, cheaply clonable handle to a [`WorkerNodeManager`].
pub type WorkerNodeManagerRef = Arc<WorkerNodeManager>;

47impl Default for WorkerNodeManager {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl WorkerNodeManager {
54    pub fn new() -> Self {
55        Self {
56            inner: RwLock::new(WorkerNodeManagerInner {
57                worker_nodes: Default::default(),
58                streaming_fragment_vnode_mapping: Default::default(),
59                serving_fragment_vnode_mapping: Default::default(),
60            }),
61            worker_node_mask: Arc::new(Default::default()),
62        }
63    }
64
65    /// Used in tests.
66    pub fn mock(worker_nodes: Vec<WorkerNode>) -> Self {
67        let inner = RwLock::new(WorkerNodeManagerInner {
68            worker_nodes,
69            streaming_fragment_vnode_mapping: HashMap::new(),
70            serving_fragment_vnode_mapping: HashMap::new(),
71        });
72        Self {
73            inner,
74            worker_node_mask: Arc::new(Default::default()),
75        }
76    }
77
78    pub fn list_worker_nodes(&self) -> Vec<WorkerNode> {
79        self.inner
80            .read()
81            .unwrap()
82            .worker_nodes
83            .iter()
84            .filter(|w| w.r#type() == WorkerType::ComputeNode)
85            .cloned()
86            .collect()
87    }
88
89    fn list_serving_worker_nodes(&self) -> Vec<WorkerNode> {
90        self.list_worker_nodes()
91            .into_iter()
92            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
93            .collect()
94    }
95
96    fn list_streaming_worker_nodes(&self) -> Vec<WorkerNode> {
97        self.list_worker_nodes()
98            .into_iter()
99            .filter(|w| w.property.as_ref().is_some_and(|p| p.is_streaming))
100            .collect()
101    }
102
103    pub fn add_worker_node(&self, node: WorkerNode) {
104        let mut write_guard = self.inner.write().unwrap();
105        match write_guard
106            .worker_nodes
107            .iter_mut()
108            .find(|w| w.id == node.id)
109        {
110            None => {
111                // insert
112                write_guard.worker_nodes.push(node);
113            }
114            Some(w) => {
115                // update
116                *w = node;
117            }
118        }
119    }
120
121    pub fn remove_worker_node(&self, node: WorkerNode) {
122        let mut write_guard = self.inner.write().unwrap();
123        write_guard.worker_nodes.retain(|x| x.id != node.id);
124    }
125
126    pub fn refresh(
127        &self,
128        nodes: Vec<WorkerNode>,
129        streaming_mapping: HashMap<FragmentId, WorkerSlotMapping>,
130        serving_mapping: HashMap<FragmentId, WorkerSlotMapping>,
131    ) {
132        let mut write_guard = self.inner.write().unwrap();
133        tracing::debug!("Refresh worker nodes {:?}.", nodes);
134        tracing::debug!(
135            "Refresh streaming vnode mapping for fragments {:?}.",
136            streaming_mapping.keys()
137        );
138        tracing::debug!(
139            "Refresh serving vnode mapping for fragments {:?}.",
140            serving_mapping.keys()
141        );
142        write_guard.worker_nodes = nodes;
143        write_guard.streaming_fragment_vnode_mapping = streaming_mapping;
144        write_guard.serving_fragment_vnode_mapping = serving_mapping;
145    }
146
147    /// If worker slot ids is empty, the scheduler may fail to schedule any task and stuck at
148    /// schedule next stage. If we do not return error in this case, needs more complex control
149    /// logic above. Report in this function makes the schedule root fail reason more clear.
150    pub fn get_workers_by_worker_slot_ids(
151        &self,
152        worker_slot_ids: &[WorkerSlotId],
153    ) -> Result<Vec<WorkerNode>> {
154        if worker_slot_ids.is_empty() {
155            return Err(BatchError::EmptyWorkerNodes);
156        }
157
158        let guard = self.inner.read().unwrap();
159
160        let worker_index: HashMap<_, _> = guard.worker_nodes.iter().map(|w| (w.id, w)).collect();
161
162        let mut workers = Vec::with_capacity(worker_slot_ids.len());
163
164        for worker_slot_id in worker_slot_ids {
165            match worker_index.get(&worker_slot_id.worker_id()) {
166                Some(worker) => workers.push((*worker).clone()),
167                None => bail!(
168                    "No worker node found for worker slot id: {}",
169                    worker_slot_id
170                ),
171            }
172        }
173
174        Ok(workers)
175    }
176
177    pub fn get_streaming_fragment_mapping(
178        &self,
179        fragment_id: &FragmentId,
180    ) -> Result<WorkerSlotMapping> {
181        self.inner
182            .read()
183            .unwrap()
184            .streaming_fragment_vnode_mapping
185            .get(fragment_id)
186            .cloned()
187            .ok_or_else(|| BatchError::StreamingVnodeMappingNotFound(*fragment_id))
188    }
189
190    pub fn insert_streaming_fragment_mapping(
191        &self,
192        fragment_id: FragmentId,
193        vnode_mapping: WorkerSlotMapping,
194    ) {
195        if self
196            .inner
197            .write()
198            .unwrap()
199            .streaming_fragment_vnode_mapping
200            .try_insert(fragment_id, vnode_mapping)
201            .is_err()
202        {
203            tracing::info!(
204                "Previous batch vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
205            );
206        }
207    }
208
209    pub fn update_streaming_fragment_mapping(
210        &self,
211        fragment_id: FragmentId,
212        vnode_mapping: WorkerSlotMapping,
213    ) {
214        let mut guard = self.inner.write().unwrap();
215        if guard
216            .streaming_fragment_vnode_mapping
217            .insert(fragment_id, vnode_mapping)
218            .is_none()
219        {
220            tracing::info!(
221                "Previous vnode mapping not found for fragment {fragment_id}, maybe offline scaling with background ddl"
222            );
223        }
224    }
225
226    pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) {
227        let mut guard = self.inner.write().unwrap();
228
229        let res = guard.streaming_fragment_vnode_mapping.remove(fragment_id);
230        match &res {
231            Some(_) => {}
232            None if OBJECT_ID_PLACEHOLDER == *fragment_id => {
233                // Do nothing for placeholder fragment.
234            }
235            None => {
236                tracing::warn!(fragment_id, "Streaming vnode mapping not found");
237            }
238        };
239    }
240
241    /// Returns fragment's vnode mapping for serving.
242    fn serving_fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
243        self.inner
244            .read()
245            .unwrap()
246            .get_serving_fragment_mapping(fragment_id)
247            .ok_or_else(|| BatchError::ServingVnodeMappingNotFound(fragment_id))
248    }
249
250    pub fn set_serving_fragment_mapping(&self, mappings: HashMap<FragmentId, WorkerSlotMapping>) {
251        let mut guard = self.inner.write().unwrap();
252        tracing::debug!(
253            "Set serving vnode mapping for fragments {:?}",
254            mappings.keys()
255        );
256        guard.serving_fragment_vnode_mapping = mappings;
257    }
258
259    pub fn upsert_serving_fragment_mapping(
260        &self,
261        mappings: HashMap<FragmentId, WorkerSlotMapping>,
262    ) {
263        let mut guard = self.inner.write().unwrap();
264        tracing::debug!(
265            "Upsert serving vnode mapping for fragments {:?}",
266            mappings.keys()
267        );
268        for (fragment_id, mapping) in mappings {
269            guard
270                .serving_fragment_vnode_mapping
271                .insert(fragment_id, mapping);
272        }
273    }
274
275    pub fn remove_serving_fragment_mapping(&self, fragment_ids: &[FragmentId]) {
276        let mut guard = self.inner.write().unwrap();
277        tracing::debug!(
278            "Delete serving vnode mapping for fragments {:?}",
279            fragment_ids
280        );
281        for fragment_id in fragment_ids {
282            guard.serving_fragment_vnode_mapping.remove(fragment_id);
283        }
284    }
285
286    fn worker_node_mask(&self) -> RwLockReadGuard<'_, HashSet<u32>> {
287        self.worker_node_mask.read().unwrap()
288    }
289
290    pub fn mask_worker_node(&self, worker_node_id: u32, duration: Duration) {
291        tracing::info!(
292            "Mask worker node {} for {:?} temporarily",
293            worker_node_id,
294            duration
295        );
296        let mut worker_node_mask = self.worker_node_mask.write().unwrap();
297        if worker_node_mask.contains(&worker_node_id) {
298            return;
299        }
300        worker_node_mask.insert(worker_node_id);
301        let worker_node_mask_ref = self.worker_node_mask.clone();
302        tokio::spawn(async move {
303            tokio::time::sleep(duration).await;
304            worker_node_mask_ref
305                .write()
306                .unwrap()
307                .remove(&worker_node_id);
308        });
309    }
310}
311
312impl WorkerNodeManagerInner {
313    fn get_serving_fragment_mapping(&self, fragment_id: FragmentId) -> Option<WorkerSlotMapping> {
314        self.serving_fragment_vnode_mapping
315            .get(&fragment_id)
316            .cloned()
317    }
318}
319
/// Selects workers for query according to `enable_barrier_read`
#[derive(Clone)]
pub struct WorkerNodeSelector {
    /// Shared handle to the live worker/mapping state.
    pub manager: WorkerNodeManagerRef,
    /// When true, pick streaming workers and streaming mappings; otherwise pick
    /// serving workers/mappings with the temporary worker mask applied.
    enable_barrier_read: bool,
}

327impl WorkerNodeSelector {
328    pub fn new(manager: WorkerNodeManagerRef, enable_barrier_read: bool) -> Self {
329        Self {
330            manager,
331            enable_barrier_read,
332        }
333    }
334
335    pub fn worker_node_count(&self) -> usize {
336        if self.enable_barrier_read {
337            self.manager.list_streaming_worker_nodes().len()
338        } else {
339            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
340                .len()
341        }
342    }
343
344    pub fn schedule_unit_count(&self) -> usize {
345        let worker_nodes = if self.enable_barrier_read {
346            self.manager.list_streaming_worker_nodes()
347        } else {
348            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
349        };
350        worker_nodes
351            .iter()
352            .map(|node| node.compute_node_parallelism())
353            .sum()
354    }
355
356    pub fn fragment_mapping(&self, fragment_id: FragmentId) -> Result<WorkerSlotMapping> {
357        if self.enable_barrier_read {
358            self.manager.get_streaming_fragment_mapping(&fragment_id)
359        } else {
360            let mapping = (self.manager.serving_fragment_mapping(fragment_id)).or_else(|_| {
361                tracing::warn!(
362                    fragment_id,
363                    "Serving fragment mapping not found, fall back to streaming one."
364                );
365                self.manager.get_streaming_fragment_mapping(&fragment_id)
366            })?;
367
368            // Filter out unavailable workers.
369            if self.manager.worker_node_mask().is_empty() {
370                Ok(mapping)
371            } else {
372                let workers = self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
373                // If it's a singleton, set max_parallelism=1 for place_vnode.
374                let max_parallelism = mapping.to_single().map(|_| 1);
375                let masked_mapping =
376                    place_vnode(Some(&mapping), &workers, max_parallelism, mapping.len())
377                        .ok_or_else(|| BatchError::EmptyWorkerNodes)?;
378                Ok(masked_mapping)
379            }
380        }
381    }
382
383    pub fn next_random_worker(&self) -> Result<WorkerNode> {
384        let worker_nodes = if self.enable_barrier_read {
385            self.manager.list_streaming_worker_nodes()
386        } else {
387            self.apply_worker_node_mask(self.manager.list_serving_worker_nodes())
388        };
389        worker_nodes
390            .choose(&mut rand::rng())
391            .ok_or_else(|| BatchError::EmptyWorkerNodes)
392            .map(|w| (*w).clone())
393    }
394
395    fn apply_worker_node_mask(&self, origin: Vec<WorkerNode>) -> Vec<WorkerNode> {
396        let mask = self.manager.worker_node_mask();
397        if origin.iter().all(|w| mask.contains(&w.id)) {
398            return origin;
399        }
400        origin
401            .into_iter()
402            .filter(|w| !mask.contains(&w.id))
403            .collect()
404    }
405}
406
#[cfg(test)]
mod tests {
    use risingwave_common::util::addr::HostAddr;
    use risingwave_pb::common::worker_node;
    use risingwave_pb::common::worker_node::Property;

    use super::*;

    /// Builds a running compute node with the given id/address; `is_serving` is
    /// always true while `is_streaming` is configurable.
    fn compute_node(id: u32, addr: &str, is_streaming: bool) -> WorkerNode {
        WorkerNode {
            id,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddr::try_from(addr).unwrap().to_protobuf()),
            state: worker_node::State::Running as i32,
            property: Some(Property {
                is_unschedulable: false,
                is_serving: true,
                is_streaming,
                ..Default::default()
            }),
            transactional_id: Some(id),
            ..Default::default()
        }
    }

    #[test]
    fn test_worker_node_manager() {
        // An empty manager exposes no workers at all.
        let manager = WorkerNodeManager::mock(vec![]);
        assert_eq!(manager.list_serving_worker_nodes().len(), 0);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_worker_nodes(), vec![]);

        // Two serving nodes, only the first of which is also streaming.
        let workers = vec![
            compute_node(1, "127.0.0.1:1234", true),
            compute_node(2, "127.0.0.1:1235", false),
        ];
        for worker in &workers {
            manager.add_worker_node(worker.clone());
        }
        assert_eq!(manager.list_serving_worker_nodes().len(), 2);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 1);
        assert_eq!(manager.list_worker_nodes(), workers);

        // Removing the streaming-capable node leaves one serving-only worker.
        manager.remove_worker_node(workers[0].clone());
        assert_eq!(manager.list_serving_worker_nodes().len(), 1);
        assert_eq!(manager.list_streaming_worker_nodes().len(), 0);
        assert_eq!(manager.list_worker_nodes(), workers[1..].to_vec());
    }
}