risingwave_simulation/ctl_ext.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::ffi::OsString;
use std::fmt::Write;
use std::sync::Arc;

use anyhow::{Result, anyhow};
use cfg_or_panic::cfg_or_panic;
use clap::Parser;
use itertools::Itertools;
use rand::seq::IteratorRandom;
use rand::{Rng, rng as thread_rng};
use risingwave_common::catalog::TableId;
use risingwave_common::hash::WorkerSlotId;
use risingwave_connector::source::{SplitImpl, SplitMetaData};
use risingwave_hummock_sdk::{CompactionGroupId, HummockSstableId};
use risingwave_pb::meta::GetClusterInfoResponse;
use risingwave_pb::meta::table_fragments::PbFragment;
use risingwave_pb::meta::table_fragments::fragment::FragmentDistributionType;
use risingwave_pb::meta::update_worker_node_schedulability_request::Schedulability;
use risingwave_pb::stream_plan::StreamNode;

use self::predicate::BoxedPredicate;
use crate::cluster::Cluster;

/// Predicates used for locating fragments.
pub mod predicate {
    use risingwave_pb::stream_plan::DispatcherType;
    use risingwave_pb::stream_plan::stream_node::NodeBody;

    use super::*;

    trait Predicate = Fn(&PbFragment) -> bool + Send + 'static;
    pub type BoxedPredicate = Box<dyn Predicate>;

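    /// Returns the root node of the fragment's operator tree.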
    fn root(fragment: &PbFragment) -> &StreamNode {
        fragment.nodes.as_ref().unwrap()
    }

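    /// Counts the operators in the tree rooted at `root` that satisfy the predicate `p`.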
    fn count(root: &StreamNode, p: &impl Fn(&StreamNode) -> bool) -> usize {
        let child = root.input.iter().map(|n| count(n, p)).sum::<usize>();
        child + if p(root) { 1 } else { 0 }
    }

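    /// Returns whether any operator in the tree rooted at `root` satisfies the predicate `p`.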
    fn any(root: &StreamNode, p: &impl Fn(&StreamNode) -> bool) -> bool {
        p(root) || root.input.iter().any(|n| any(n, p))
    }

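    /// Returns whether all operators in the tree rooted at `root` satisfy the predicate `p`.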
    fn all(root: &StreamNode, p: &impl Fn(&StreamNode) -> bool) -> bool {
        p(root) && root.input.iter().all(|n| all(n, p))
    }

    /// There are exactly `n` operators whose identity contains `s` in the fragment.
    pub fn identity_contains_n(n: usize, s: impl Into<String>) -> BoxedPredicate {
        let s: String = s.into();
        let p = move |f: &PbFragment| {
            count(root(f), &|n| {
                n.identity.to_lowercase().contains(&s.to_lowercase())
            }) == n
        };
        Box::new(p)
    }

    /// There exists an operator whose identity contains `s` in the fragment (case-insensitive).
    pub fn identity_contains(s: impl Into<String>) -> BoxedPredicate {
        let s: String = s.into();
        let p = move |f: &PbFragment| {
            any(root(f), &|n| {
                n.identity.to_lowercase().contains(&s.to_lowercase())
            })
        };
        Box::new(p)
    }

    /// There does not exist any operator whose identity contains `s` in the fragment.
    pub fn no_identity_contains(s: impl Into<String>) -> BoxedPredicate {
        let s: String = s.into();
        let p = move |f: &PbFragment| {
            all(root(f), &|n| {
                !n.identity.to_lowercase().contains(&s.to_lowercase())
            })
        };
        Box::new(p)
    }

    /// The fragment can be rescheduled. Used for locating a random fragment.
    pub fn can_reschedule() -> BoxedPredicate {
        let p = |f: &PbFragment| {
            // The rescheduling of no-shuffle downstreams must be derived from the most upstream
            // fragment. So if a fragment has no-shuffle upstreams, it cannot be rescheduled.
            !any(root(f), &|n| {
                let Some(NodeBody::Merge(merge)) = &n.node_body else {
                    return false;
                };
                merge.upstream_dispatcher_type() == DispatcherType::NoShuffle
            })
        };
        Box::new(p)
    }

    /// The fragment with the given id.
    pub fn id(id: u32) -> BoxedPredicate {
        let p = move |f: &PbFragment| f.fragment_id == id;
        Box::new(p)
    }
}

#[derive(Debug)]
pub struct Fragment {
    pub inner: risingwave_pb::meta::table_fragments::Fragment,

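    /// The cluster info from which this fragment was located.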
    r: Arc<GetClusterInfoResponse>,
}

impl Fragment {
    /// The fragment id.
    pub fn id(&self) -> u32 {
        self.inner.fragment_id
    }

    /// Generate a reschedule plan for the fragment.
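    ///
    /// The plan is a string of the form `<fragment_id>:[<worker_id>:<actor diff>, ...]`, where
    /// each entry is the net change in the number of actors on that worker. An empty string is
    /// returned if the removed and added slots cancel out on every worker.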
    pub fn reschedule(
        &self,
        remove: impl AsRef<[WorkerSlotId]>,
        add: impl AsRef<[WorkerSlotId]>,
    ) -> String {
        let remove = remove.as_ref();
        let add = add.as_ref();

        let mut worker_decreased = HashMap::new();
        for worker_slot in remove {
            let worker_id = worker_slot.worker_id();
            *worker_decreased.entry(worker_id).or_insert(0) += 1;
        }

        let mut worker_increased = HashMap::new();
        for worker_slot in add {
            let worker_id = worker_slot.worker_id();
            *worker_increased.entry(worker_id).or_insert(0) += 1;
        }

        let worker_ids: HashSet<_> = worker_increased
            .keys()
            .chain(worker_decreased.keys())
            .cloned()
            .collect();

        let mut worker_actor_diff = HashMap::new();

        for worker_id in worker_ids {
            let increased = worker_increased.remove(&worker_id).unwrap_or(0);
            let decreased = worker_decreased.remove(&worker_id).unwrap_or(0);
            let diff = increased - decreased;
            if diff != 0 {
                worker_actor_diff.insert(worker_id, diff);
            }
        }

        let mut f = String::new();

        if !worker_actor_diff.is_empty() {
            let worker_diff_str = worker_actor_diff
                .into_iter()
                .map(|(k, v)| format!("{}:{}", k, v))
                .join(", ");

            write!(f, "{}", self.id()).unwrap();
            write!(f, ":[{}]", worker_diff_str).unwrap();
        }

        f
    }

    /// Generate a random reschedule plan for the fragment.
    ///
    /// Consumes `self` as the actor info will be stale after rescheduling.
    pub fn random_reschedule(self) -> String {
        let all_worker_slots = self.all_worker_slots();
        let used_worker_slots = self.used_worker_slots();

        let rng = &mut thread_rng();
        let target_worker_slot_count = match self.inner.distribution_type() {
            FragmentDistributionType::Unspecified => unreachable!(),
            FragmentDistributionType::Single => 1,
            FragmentDistributionType::Hash => rng.random_range(1..=all_worker_slots.len()),
        };

        let target_worker_slots: HashSet<_> = all_worker_slots
            .into_iter()
            .choose_multiple(rng, target_worker_slot_count)
            .into_iter()
            .collect();

        let remove = used_worker_slots
            .difference(&target_worker_slots)
            .copied()
            .collect_vec();

        let add = target_worker_slots
            .difference(&used_worker_slots)
            .copied()
            .collect_vec();

        self.reschedule(remove, add)
    }

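    /// The parallelism of each worker node in the cluster, keyed by worker id.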
    pub fn all_worker_count(&self) -> HashMap<u32, usize> {
        self.r
            .worker_nodes
            .iter()
            .map(|w| (w.id, w.compute_node_parallelism()))
            .collect()
    }

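    /// All worker slots in the cluster.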
    pub fn all_worker_slots(&self) -> HashSet<WorkerSlotId> {
        self.all_worker_count()
            .into_iter()
            .flat_map(|(k, v)| (0..v).map(move |idx| WorkerSlotId::new(k, idx as _)))
            .collect()
    }

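    /// The current parallelism (number of actors) of this fragment.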
    pub fn parallelism(&self) -> usize {
        self.inner.actors.len()
    }

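    /// The number of this fragment's actors on each worker node, keyed by worker id.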
    pub fn used_worker_count(&self) -> HashMap<u32, usize> {
        let actor_to_worker: HashMap<_, _> = self
            .r
            .table_fragments
            .iter()
            .flat_map(|tf| {
                tf.actor_status
                    .iter()
                    .map(|(&actor_id, status)| (actor_id, status.worker_id()))
            })
            .collect();

        self.inner
            .actors
            .iter()
            .map(|a| actor_to_worker[&a.actor_id])
            .fold(HashMap::<u32, usize>::new(), |mut acc, num| {
                *acc.entry(num).or_insert(0) += 1;
                acc
            })
    }

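    /// The worker slots currently occupied by this fragment's actors.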
    pub fn used_worker_slots(&self) -> HashSet<WorkerSlotId> {
        self.used_worker_count()
            .into_iter()
            .flat_map(|(k, v)| (0..v).map(move |idx| WorkerSlotId::new(k, idx as _)))
            .collect()
    }
}

impl Cluster {
    /// Locate fragments that satisfy all the predicates.
    #[cfg_or_panic(madsim)]
    pub async fn locate_fragments(
        &mut self,
        predicates: impl IntoIterator<Item = BoxedPredicate>,
    ) -> Result<Vec<Fragment>> {
        let predicates = predicates.into_iter().collect_vec();

        let fragments = self
            .ctl
            .spawn(async move {
                let r: Arc<_> = risingwave_ctl::cmd_impl::meta::get_cluster_info(
                    &risingwave_ctl::common::CtlContext::default(),
                )
                .await?
                .into();

                let mut results = vec![];
                for tf in &r.table_fragments {
                    for f in tf.fragments.values() {
                        let selected = predicates.iter().all(|p| p(f));
                        if selected {
                            results.push(Fragment {
                                inner: f.clone(),
                                r: r.clone(),
                            });
                        }
                    }
                }

                Ok::<_, anyhow::Error>(results)
            })
            .await??;

        Ok(fragments)
    }

    /// Locate exactly one fragment that satisfies all the predicates.
    pub async fn locate_one_fragment(
        &mut self,
        predicates: impl IntoIterator<Item = BoxedPredicate>,
    ) -> Result<Fragment> {
        let [fragment]: [_; 1] = self
            .locate_fragments(predicates)
            .await?
            .try_into()
            .map_err(|fs| anyhow!("not exactly one fragment: {fs:#?}"))?;
        Ok(fragment)
    }

    /// Locate a random fragment that is reschedulable.
    pub async fn locate_random_fragment(&mut self) -> Result<Fragment> {
        self.locate_fragments([predicate::can_reschedule()])
            .await?
            .into_iter()
            .choose(&mut thread_rng())
            .ok_or_else(|| anyhow!("no reschedulable fragment"))
    }

    /// Locate some random fragments that are reschedulable.
    pub async fn locate_random_fragments(&mut self) -> Result<Vec<Fragment>> {
        let fragments = self.locate_fragments([predicate::can_reschedule()]).await?;
        let len = thread_rng().random_range(1..=fragments.len());
        let selected = fragments
            .into_iter()
            .choose_multiple(&mut thread_rng(), len);
        Ok(selected)
    }

    /// Locate a fragment with the given id.
    pub async fn locate_fragment_by_id(&mut self, id: u32) -> Result<Fragment> {
        self.locate_one_fragment([predicate::id(id)]).await
    }

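    /// Get the current cluster info from the meta service via `risingwave_ctl`.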
    #[cfg_or_panic(madsim)]
    pub async fn get_cluster_info(&self) -> Result<GetClusterInfoResponse> {
        let response = self
            .ctl
            .spawn(async move {
                risingwave_ctl::cmd_impl::meta::get_cluster_info(
                    &risingwave_ctl::common::CtlContext::default(),
                )
                .await
            })
            .await??;
        Ok(response)
    }

    /// Returns a map of `actor_id -> splits`, where `splits` is a comma-separated list of split ids.
    pub async fn list_source_splits(&self) -> Result<BTreeMap<u32, String>> {
        let info = self.get_cluster_info().await?;
        let mut res = BTreeMap::new();

        for table in info.table_fragments {
            for (actor_id, splits) in table.actor_splits {
                let splits = splits
                    .splits
                    .iter()
                    .map(|split| SplitImpl::try_from(split).unwrap())
                    .map(|split| split.id())
                    .collect_vec()
                    .join(",");
                res.insert(actor_id, splits);
            }
        }

        Ok(res)
    }

    // Update the schedulability of the given worker nodes.
    #[cfg_or_panic(madsim)]
    async fn update_worker_node_schedulability(
        &self,
        worker_ids: Vec<u32>,
        target: Schedulability,
    ) -> Result<()> {
        let worker_ids = worker_ids
            .into_iter()
            .map(|id| id.to_string())
            .collect_vec();

        let _ = self
            .ctl
            .spawn(async move {
                risingwave_ctl::cmd_impl::scale::update_schedulability(
                    &risingwave_ctl::common::CtlContext::default(),
                    worker_ids,
                    target,
                )
                .await
            })
            .await?;
        Ok(())
    }

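    /// Mark the given worker node as unschedulable.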
    pub async fn cordon_worker(&self, id: u32) -> Result<()> {
        self.update_worker_node_schedulability(vec![id], Schedulability::Unschedulable)
            .await
    }

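    /// Mark the given worker node as schedulable again.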
    pub async fn uncordon_worker(&self, id: u32) -> Result<()> {
        self.update_worker_node_schedulability(vec![id], Schedulability::Schedulable)
            .await
    }

    /// Reschedule with the given `plan`. Check the documentation of
    /// [`risingwave_ctl::cmd_impl::meta::reschedule`] for more details.
    pub async fn reschedule(&mut self, plan: impl Into<String>) -> Result<()> {
        self.reschedule_helper(plan, false).await
    }

    /// Same as [`Self::reschedule`], but resolves the no-shuffle upstreams.
    pub async fn reschedule_resolve_no_shuffle(&mut self, plan: impl Into<String>) -> Result<()> {
        self.reschedule_helper(plan, true).await
    }

    #[cfg_or_panic(madsim)]
    async fn reschedule_helper(
        &mut self,
        plan: impl Into<String>,
        resolve_no_shuffle_upstream: bool,
    ) -> Result<()> {
        let plan = plan.into();

        let revision = self
            .ctl
            .spawn(async move {
                let r = risingwave_ctl::cmd_impl::meta::get_cluster_info(
                    &risingwave_ctl::common::CtlContext::default(),
                )
                .await?;

                Ok::<_, anyhow::Error>(r.revision)
            })
            .await??;

        self.ctl
            .spawn(async move {
                let revision = format!("{}", revision);
                let mut v = vec![
                    "meta",
                    "reschedule",
                    "--plan",
                    plan.as_ref(),
                    "--revision",
                    &revision,
                ];

                if resolve_no_shuffle_upstream {
                    v.push("--resolve-no-shuffle");
                }

                start_ctl(v).await
            })
            .await??;

        Ok(())
    }

    /// Pause all data sources in the cluster.
    #[cfg_or_panic(madsim)]
    pub async fn pause(&mut self) -> Result<()> {
        self.ctl.spawn(start_ctl(["meta", "pause"])).await??;
        Ok(())
    }

    /// Resume all data sources in the cluster.
    #[cfg_or_panic(madsim)]
    pub async fn resume(&mut self) -> Result<()> {
        self.ctl.spawn(start_ctl(["meta", "resume"])).await??;
        Ok(())
    }

    /// Throttle a materialized view in the cluster.
    #[cfg_or_panic(madsim)]
    pub async fn throttle_mv(&mut self, table_id: TableId, rate_limit: Option<u32>) -> Result<()> {
        self.ctl
            .spawn(async move {
                let mut command: Vec<String> = vec![
                    "throttle".into(),
                    "mv".into(),
                    table_id.table_id.to_string(),
                ];
                if let Some(rate_limit) = rate_limit {
                    command.push(rate_limit.to_string());
                }
                start_ctl(command).await
            })
            .await??;
        Ok(())
    }

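    /// Split the given table out of the compaction group via `risingwave_ctl`.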
    #[cfg_or_panic(madsim)]
    pub async fn split_compaction_group(
        &mut self,
        compaction_group_id: CompactionGroupId,
        table_id: HummockSstableId,
    ) -> Result<()> {
        self.ctl
            .spawn(async move {
                let command: Vec<String> = vec![
                    "hummock".into(),
                    "split-compaction-group".into(),
                    "--compaction-group-id".into(),
                    compaction_group_id.to_string(),
                    "--table-ids".into(),
                    table_id.to_string(),
                ];
                start_ctl(command).await
            })
            .await??;
        Ok(())
    }

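    /// Trigger a manual compaction on the given level of the compaction group.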
    #[cfg_or_panic(madsim)]
    pub async fn trigger_manual_compaction(
        &mut self,
        compaction_group_id: CompactionGroupId,
        level_id: u32,
    ) -> Result<()> {
        self.ctl
            .spawn(async move {
                let command: Vec<String> = vec![
                    "hummock".into(),
                    "trigger-manual-compaction".into(),
                    "--compaction-group-id".into(),
                    compaction_group_id.to_string(),
                    "--level".into(),
                    level_id.to_string(),
                ];
                start_ctl(command).await
            })
            .await??;
        Ok(())
    }
}

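/// Run a `risingwave_ctl` command in-process with the given arguments.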
#[cfg_attr(not(madsim), allow(dead_code))]
async fn start_ctl<S, I>(args: I) -> Result<()>
where
    S: Into<OsString>,
    I: IntoIterator<Item = S>,
{
    let args = std::iter::once("ctl".into()).chain(args.into_iter().map(|s| s.into()));
    let opts = risingwave_ctl::CliOpts::parse_from(args);
    let context = risingwave_ctl::common::CtlContext::default();
    risingwave_ctl::start_fallible(opts, &context).await
}