risingwave_simulation/
ctl_ext.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::ffi::OsString;
17use std::sync::Arc;
18
19use anyhow::{Result, anyhow};
20use cfg_or_panic::cfg_or_panic;
21use clap::Parser;
22use itertools::Itertools;
23use rand::seq::IteratorRandom;
24use rand::{Rng, rng as thread_rng};
25use risingwave_common::catalog::TableId;
26use risingwave_common::hash::WorkerSlotId;
27use risingwave_common::id::WorkerId;
28use risingwave_connector::source::{SplitImpl, SplitMetaData};
29use risingwave_hummock_sdk::{CompactionGroupId, HummockSstableId};
30use risingwave_pb::id::{ActorId, FragmentId};
31use risingwave_pb::meta::GetClusterInfoResponse;
32use risingwave_pb::meta::table_fragments::PbFragment;
33use risingwave_pb::meta::update_worker_node_schedulability_request::Schedulability;
34use risingwave_pb::stream_plan::StreamNode;
35
36use self::predicate::BoxedPredicate;
37use crate::cluster::Cluster;
38
39/// Predicates used for locating fragments.
40pub mod predicate {
41    use risingwave_pb::stream_plan::DispatcherType;
42    use risingwave_pb::stream_plan::stream_node::NodeBody;
43
44    use super::*;
45
46    trait Predicate = Fn(&PbFragment) -> bool + Send + 'static;
47    pub type BoxedPredicate = Box<dyn Predicate>;
48
49    fn root(fragment: &PbFragment) -> &StreamNode {
50        fragment.nodes.as_ref().unwrap()
51    }
52
53    fn count(root: &StreamNode, p: &impl Fn(&StreamNode) -> bool) -> usize {
54        let child = root.input.iter().map(|n| count(n, p)).sum::<usize>();
55        child + if p(root) { 1 } else { 0 }
56    }
57
58    fn any(root: &StreamNode, p: &impl Fn(&StreamNode) -> bool) -> bool {
59        p(root) || root.input.iter().any(|n| any(n, p))
60    }
61
62    fn all(root: &StreamNode, p: &impl Fn(&StreamNode) -> bool) -> bool {
63        p(root) && root.input.iter().all(|n| all(n, p))
64    }
65
66    /// There're exactly `n` operators whose identity contains `s` in the fragment.
67    pub fn identity_contains_n(n: usize, s: impl Into<String>) -> BoxedPredicate {
68        let s: String = s.into();
69        let p = move |f: &PbFragment| {
70            count(root(f), &|n| {
71                n.identity.to_lowercase().contains(&s.to_lowercase())
72            }) == n
73        };
74        Box::new(p)
75    }
76
77    /// There exists operators whose identity contains `s` in the fragment (case insensitive).
78    pub fn identity_contains(s: impl Into<String>) -> BoxedPredicate {
79        let s: String = s.into();
80        let p = move |f: &PbFragment| {
81            any(root(f), &|n| {
82                n.identity.to_lowercase().contains(&s.to_lowercase())
83            })
84        };
85        Box::new(p)
86    }
87
88    /// There does not exist any operator whose identity contains `s` in the fragment.
89    pub fn no_identity_contains(s: impl Into<String>) -> BoxedPredicate {
90        let s: String = s.into();
91        let p = move |f: &PbFragment| {
92            all(root(f), &|n| {
93                !n.identity.to_lowercase().contains(&s.to_lowercase())
94            })
95        };
96        Box::new(p)
97    }
98
99    /// The fragment is able to be rescheduled. Used for locating random fragment.
100    pub fn can_reschedule() -> BoxedPredicate {
101        let p = |f: &PbFragment| {
102            // The rescheduling of no-shuffle downstreams must be derived from the most upstream
103            // fragment. So if a fragment has no-shuffle upstreams, it cannot be rescheduled.
104            !any(root(f), &|n| {
105                let Some(NodeBody::Merge(merge)) = &n.node_body else {
106                    return false;
107                };
108                merge.upstream_dispatcher_type() == DispatcherType::NoShuffle
109            })
110        };
111        Box::new(p)
112    }
113
114    /// The fragment with the given id.
115    pub fn id(id: u32) -> BoxedPredicate {
116        let p = move |f: &PbFragment| f.fragment_id == id;
117        Box::new(p)
118    }
119}
120
121#[derive(Debug)]
122pub struct Fragment {
123    pub inner: risingwave_pb::meta::table_fragments::Fragment,
124
125    r: Arc<GetClusterInfoResponse>,
126}
127
128impl Fragment {
129    /// The fragment id.
130    pub fn id(&self) -> FragmentId {
131        self.inner.fragment_id
132    }
133
134    pub fn all_worker_count(&self) -> HashMap<WorkerId, usize> {
135        self.r
136            .worker_nodes
137            .iter()
138            .map(|w| (w.id, w.compute_node_parallelism()))
139            .collect()
140    }
141
142    pub fn all_worker_slots(&self) -> HashSet<WorkerSlotId> {
143        self.all_worker_count()
144            .into_iter()
145            .flat_map(|(k, v)| (0..v).map(move |idx| WorkerSlotId::new(k, idx as _)))
146            .collect()
147    }
148
149    pub fn parallelism(&self) -> usize {
150        self.inner.actors.len()
151    }
152
153    pub fn used_worker_count(&self) -> HashMap<WorkerId, usize> {
154        let actor_to_worker: HashMap<_, _> = self
155            .r
156            .table_fragments
157            .iter()
158            .flat_map(|tf| {
159                tf.actor_status
160                    .iter()
161                    .map(|(&actor_id, status)| (actor_id, status.worker_id()))
162            })
163            .collect();
164
165        self.inner
166            .actors
167            .iter()
168            .map(|a| actor_to_worker[&a.actor_id])
169            .fold(HashMap::<WorkerId, usize>::new(), |mut acc, num| {
170                *acc.entry(num).or_insert(0) += 1;
171                acc
172            })
173    }
174
175    pub fn used_worker_slots(&self) -> HashSet<WorkerSlotId> {
176        self.used_worker_count()
177            .into_iter()
178            .flat_map(|(k, v)| (0..v).map(move |idx| WorkerSlotId::new(k, idx as _)))
179            .collect()
180    }
181}
182
183impl Cluster {
184    /// Locate fragments that satisfy all the predicates.
185    #[cfg_or_panic(madsim)]
186    pub async fn locate_fragments(
187        &mut self,
188        predicates: impl IntoIterator<Item = BoxedPredicate>,
189    ) -> Result<Vec<Fragment>> {
190        let predicates = predicates.into_iter().collect_vec();
191
192        let fragments = self
193            .ctl
194            .spawn(async move {
195                let r: Arc<_> = risingwave_ctl::cmd_impl::meta::get_cluster_info(
196                    &risingwave_ctl::common::CtlContext::default(),
197                )
198                .await?
199                .into();
200
201                let mut results = vec![];
202                for tf in &r.table_fragments {
203                    for f in tf.fragments.values() {
204                        let selected = predicates.iter().all(|p| p(f));
205                        if selected {
206                            results.push(Fragment {
207                                inner: f.clone(),
208                                r: r.clone(),
209                            });
210                        }
211                    }
212                }
213
214                Ok::<_, anyhow::Error>(results)
215            })
216            .await??;
217
218        Ok(fragments)
219    }
220
221    /// Locate exactly one fragment that satisfies all the predicates.
222    pub async fn locate_one_fragment(
223        &mut self,
224        predicates: impl IntoIterator<Item = BoxedPredicate>,
225    ) -> Result<Fragment> {
226        let [fragment]: [_; 1] = self
227            .locate_fragments(predicates)
228            .await?
229            .try_into()
230            .map_err(|fs| anyhow!("not exactly one fragment: {fs:#?}"))?;
231        Ok(fragment)
232    }
233
234    /// Locate a random fragment that is reschedulable.
235    pub async fn locate_random_fragment(&mut self) -> Result<Fragment> {
236        self.locate_fragments([predicate::can_reschedule()])
237            .await?
238            .into_iter()
239            .choose(&mut thread_rng())
240            .ok_or_else(|| anyhow!("no reschedulable fragment"))
241    }
242
243    /// Locate some random fragments that are reschedulable.
244    pub async fn locate_random_fragments(&mut self) -> Result<Vec<Fragment>> {
245        let fragments = self.locate_fragments([predicate::can_reschedule()]).await?;
246        let len = thread_rng().random_range(1..=fragments.len());
247        let selected = fragments
248            .into_iter()
249            .choose_multiple(&mut thread_rng(), len);
250        Ok(selected)
251    }
252
253    /// Locate a fragment with the given id.
254    pub async fn locate_fragment_by_id(&mut self, id: FragmentId) -> Result<Fragment> {
255        self.locate_one_fragment([predicate::id(id.as_raw_id())])
256            .await
257    }
258
259    #[cfg_or_panic(madsim)]
260    pub async fn get_cluster_info(&self) -> Result<GetClusterInfoResponse> {
261        let response = self
262            .ctl
263            .spawn(async move {
264                risingwave_ctl::cmd_impl::meta::get_cluster_info(
265                    &risingwave_ctl::common::CtlContext::default(),
266                )
267                .await
268            })
269            .await??;
270        Ok(response)
271    }
272
273    /// `actor_id -> splits`
274    pub async fn list_source_splits(&self) -> Result<BTreeMap<ActorId, String>> {
275        let info = self.get_cluster_info().await?;
276        let mut res = BTreeMap::new();
277
278        for (actor_id, splits) in info.actor_splits {
279            let splits = splits
280                .splits
281                .iter()
282                .map(|split| SplitImpl::try_from(split).unwrap())
283                .map(|split| split.id())
284                .collect_vec()
285                .join(",");
286            res.insert(actor_id, splits);
287        }
288
289        Ok(res)
290    }
291
292    // update node schedulability
293    #[cfg_or_panic(madsim)]
294    async fn update_worker_node_schedulability(
295        &self,
296        worker_ids: Vec<WorkerId>,
297        target: Schedulability,
298    ) -> Result<()> {
299        let worker_ids = worker_ids
300            .into_iter()
301            .map(|id| id.to_string())
302            .collect_vec();
303
304        let _ = self
305            .ctl
306            .spawn(async move {
307                risingwave_ctl::cmd_impl::scale::update_schedulability(
308                    &risingwave_ctl::common::CtlContext::default(),
309                    worker_ids,
310                    target,
311                )
312                .await
313            })
314            .await?;
315        Ok(())
316    }
317
318    pub async fn cordon_worker(&self, id: WorkerId) -> Result<()> {
319        self.update_worker_node_schedulability(vec![id], Schedulability::Unschedulable)
320            .await
321    }
322
323    pub async fn uncordon_worker(&self, id: WorkerId) -> Result<()> {
324        self.update_worker_node_schedulability(vec![id], Schedulability::Schedulable)
325            .await
326    }
327
328    /// Pause all data sources in the cluster.
329    #[cfg_or_panic(madsim)]
330    pub async fn pause(&mut self) -> Result<()> {
331        self.ctl.spawn(start_ctl(["meta", "pause"])).await??;
332        Ok(())
333    }
334
335    /// Resume all data sources in the cluster.
336    #[cfg_or_panic(madsim)]
337    pub async fn resume(&mut self) -> Result<()> {
338        self.ctl.spawn(start_ctl(["meta", "resume"])).await??;
339        Ok(())
340    }
341
342    /// Throttle a Mv in the cluster
343    #[cfg_or_panic(madsim)]
344    pub async fn throttle_mv(&mut self, table_id: TableId, rate_limit: Option<u32>) -> Result<()> {
345        self.ctl
346            .spawn(async move {
347                let mut command: Vec<String> =
348                    vec!["throttle".into(), "mv".into(), table_id.to_string()];
349                if let Some(rate_limit) = rate_limit {
350                    command.push(rate_limit.to_string());
351                }
352                start_ctl(command).await
353            })
354            .await??;
355        Ok(())
356    }
357
358    #[cfg_or_panic(madsim)]
359    pub async fn split_compaction_group(
360        &mut self,
361        compaction_group_id: CompactionGroupId,
362        table_id: HummockSstableId,
363    ) -> Result<()> {
364        self.ctl
365            .spawn(async move {
366                let mut command: Vec<String> = vec![
367                    "hummock".into(),
368                    "split-compaction-group".into(),
369                    "--compaction-group-id".into(),
370                    compaction_group_id.to_string(),
371                    "--table-ids".into(),
372                    table_id.to_string(),
373                ];
374                start_ctl(command).await
375            })
376            .await??;
377        Ok(())
378    }
379
380    #[cfg_or_panic(madsim)]
381    pub async fn trigger_manual_compaction(
382        &mut self,
383        compaction_group_id: CompactionGroupId,
384        level_id: u32,
385    ) -> Result<()> {
386        self.ctl
387            .spawn(async move {
388                let mut command: Vec<String> = vec![
389                    "hummock".into(),
390                    "trigger-manual-compaction".into(),
391                    "--compaction-group-id".into(),
392                    compaction_group_id.to_string(),
393                    "--level".into(),
394                    level_id.to_string(),
395                ];
396                start_ctl(command).await
397            })
398            .await??;
399        Ok(())
400    }
401}
402
403#[cfg_attr(not(madsim), allow(dead_code))]
404async fn start_ctl<S, I>(args: I) -> Result<()>
405where
406    S: Into<OsString>,
407    I: IntoIterator<Item = S>,
408{
409    let args = std::iter::once("ctl".into()).chain(args.into_iter().map(|s| s.into()));
410    let opts = risingwave_ctl::CliOpts::parse_from(args);
411    let context = risingwave_ctl::common::CtlContext::default();
412    risingwave_ctl::start_fallible(opts, &context).await
413}