risingwave_ctl/cmd_impl/scale/
resize.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::process::exit;
17
18use itertools::Itertools;
19use risingwave_pb::meta::GetClusterInfoResponse;
20use risingwave_pb::meta::update_worker_node_schedulability_request::Schedulability;
21use thiserror_ext::AsReport;
22
23use crate::common::CtlContext;
24
/// Prints a `format!`-style message to stdout, then terminates the whole
/// process with exit status 1.
///
/// The expansion is a block that never returns, so the macro can stand in for
/// a value of any type (for example as a `match` arm).
/// The unqualified `exit` is resolved at the call site, so `std::process::exit`
/// must be in scope wherever this macro is invoked.
macro_rules! fail {
    ($($fmt_and_args:tt)*) => {{
        println!($($fmt_and_args)*);
        exit(1)
    }};
}
31
32pub async fn update_schedulability(
33    context: &CtlContext,
34    workers: Vec<String>,
35    target: Schedulability,
36) -> anyhow::Result<()> {
37    let meta_client = context.meta_client().await?;
38
39    let GetClusterInfoResponse { worker_nodes, .. } = match meta_client.get_cluster_info().await {
40        Ok(resp) => resp,
41        Err(e) => {
42            fail!("Failed to get cluster info: {}", e.as_report());
43        }
44    };
45
46    let worker_ids: HashSet<_> = worker_nodes.iter().map(|worker| worker.id).collect();
47
48    let worker_index_by_host: HashMap<_, _> = worker_nodes
49        .iter()
50        .map(|worker| {
51            let host = worker.get_host().expect("worker host must be set");
52            (format!("{}:{}", host.host, host.port), worker.id)
53        })
54        .collect();
55
56    let mut target_worker_ids = HashSet::new();
57
58    for worker in workers {
59        let worker_id = worker
60            .parse::<u32>()
61            .ok()
62            .or_else(|| worker_index_by_host.get(&worker).cloned());
63
64        if let Some(worker_id) = worker_id
65            && worker_ids.contains(&worker_id)
66        {
67            if !target_worker_ids.insert(worker_id) {
68                println!("Warn: {} and {} are the same worker", worker, worker_id);
69            }
70        } else {
71            fail!("Invalid worker id: {}", worker);
72        }
73    }
74
75    let target_worker_ids = target_worker_ids.into_iter().collect_vec();
76
77    meta_client
78        .update_schedulability(&target_worker_ids, target)
79        .await?;
80
81    Ok(())
82}