risingwave_common_rate_limit/
lib.rs1use std::future::Future;
16use std::num::NonZeroU64;
17use std::ops::Deref;
18use std::pin::Pin;
19use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
20use std::sync::{Arc, LazyLock};
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use arc_swap::ArcSwap;
25use parking_lot::Mutex;
26use pin_project_lite::pin_project;
27use risingwave_common::catalog::TableId;
28use risingwave_common::metrics::LabelGuardedUintGaugeVec;
29use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
30use risingwave_common_metrics::{
31 LabelGuardedUintGauge, register_guarded_uint_gauge_vec_with_registry,
32};
33use tokio::sync::oneshot;
34use tokio::time::Sleep;
35
36static METRICS: LazyLock<LabelGuardedUintGaugeVec<1>> = LazyLock::new(|| {
37 register_guarded_uint_gauge_vec_with_registry!(
38 "backfill_rate_limit_bytes",
39 "backfill rate limit bytes per second",
40 &["table_id"],
41 &GLOBAL_METRICS_REGISTRY
42 )
43 .unwrap()
44});
45
46pin_project! {
47 #[derive(Debug)]
48 #[project = DelayProj]
49 pub enum Delay {
50 Noop,
51 Sleep{#[pin] sleep: Sleep},
52 Wait{#[pin] rx: oneshot::Receiver<()> },
53 Infinite,
54 }
55}
56
57impl Delay {
58 pub fn new(duration: Duration) -> Self {
59 match duration {
60 Duration::ZERO => Self::Noop,
61 Duration::MAX => Self::Infinite,
62 dur => Self::Sleep {
63 sleep: tokio::time::sleep(dur),
64 },
65 }
66 }
67}
68
69impl From<Duration> for Delay {
70 fn from(value: Duration) -> Self {
71 Self::new(value)
72 }
73}
74
75impl Future for Delay {
76 type Output = ();
77
78 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79 match self.project() {
80 DelayProj::Noop => Poll::Ready(()),
81 DelayProj::Sleep { sleep } => sleep.poll(cx),
82 DelayProj::Wait { rx } => rx.poll(cx).map(|_| ()),
83 DelayProj::Infinite => Poll::Pending,
84 }
85 }
86}
87
/// A rate limit setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// No limit is applied.
    Disabled,
    /// A fixed, non-zero rate.
    Fixed(NonZeroU64),
    /// Everything is paused.
    Pause,
}

impl RateLimit {
    /// Returns `true` only for [`RateLimit::Pause`].
    pub fn is_paused(&self) -> bool {
        matches!(self, Self::Pause)
    }

    /// Numeric form of the limit; see the `From<RateLimit> for u64` impl.
    pub fn to_u64(self) -> u64 {
        self.into()
    }
}

impl From<RateLimit> for u64 {
    /// Maps `Disabled` to `u64::MAX`, `Fixed(rate)` to `rate`, and `Pause` to `0`.
    fn from(rate_limit: RateLimit) -> Self {
        match rate_limit {
            RateLimit::Disabled => u64::MAX,
            RateLimit::Fixed(rate) => rate.get(),
            RateLimit::Pause => 0,
        }
    }
}

impl From<Option<u32>> for RateLimit {
    /// Maps `None` to `Disabled`, `Some(0)` to `Pause`, and any other
    /// `Some(rate)` to `Fixed(rate)`.
    fn from(value: Option<u32>) -> Self {
        let Some(rate) = value else {
            return Self::Disabled;
        };
        // `NonZeroU64::new` returns `None` exactly for zero, which is the pause
        // case, so no `unsafe` construction is needed.
        NonZeroU64::new(u64::from(rate)).map_or(Self::Pause, Self::Fixed)
    }
}
130
131#[derive(Debug)]
132pub enum Check {
133 Ok,
134 Retry(Duration),
135 RetryAfter(oneshot::Receiver<()>),
136}
137
138impl Check {
139 pub fn is_ok(&self) -> bool {
140 matches!(self, Self::Ok)
141 }
142}
143
144pub trait RateLimiterTrait: Send + Sync + 'static {
146 fn rate_limit(&self) -> RateLimit;
148
149 fn check(&self, quota: u64) -> Check;
157}
158
159pub struct RateLimiter {
161 inner: ArcSwap<Box<dyn RateLimiterTrait>>,
162}
163
164impl RateLimiter {
165 fn new_inner(rate_limit: RateLimit) -> Box<dyn RateLimiterTrait> {
166 match rate_limit {
167 RateLimit::Disabled => Box::new(InfiniteRatelimiter),
168 RateLimit::Fixed(rate) => Box::new(FixedRateLimiter::new(rate)),
169 RateLimit::Pause => Box::new(PausedRateLimiter::default()),
170 }
171 }
172
173 pub fn new(rate_limit: RateLimit) -> Self {
175 let inner: Box<dyn RateLimiterTrait> = Self::new_inner(rate_limit);
176 let inner = ArcSwap::new(Arc::new(inner));
177 Self { inner }
178 }
179
180 pub fn update(&self, rate_limit: RateLimit) -> RateLimit {
184 let old = self.rate_limit();
185 if self.rate_limit() == rate_limit {
186 return old;
187 }
188 let inner = Self::new_inner(rate_limit);
189 self.inner.store(Arc::new(inner));
190 old
191 }
192
193 pub fn monitored(self, table_id: impl Into<TableId>) -> MonitoredRateLimiter {
195 let metric = METRICS.with_guarded_label_values(&[&table_id.into().to_string()]);
196 let rate_limit = AtomicU64::new(self.rate_limit().to_u64());
197 MonitoredRateLimiter {
198 inner: self,
199 metric,
200 rate_limit,
201 }
202 }
203
204 pub fn rate_limit(&self) -> RateLimit {
205 self.inner.load().rate_limit()
206 }
207
208 pub fn check(&self, quota: u64) -> Check {
209 self.inner.load().check(quota)
210 }
211
212 pub async fn wait(&self, quota: u64) {
213 loop {
214 match self.check(quota) {
215 Check::Ok => return,
216 Check::Retry(duration) => {
217 tokio::time::sleep(duration).await;
218 }
219 Check::RetryAfter(rx) => {
220 let _ = rx.await;
221 }
222 }
223 }
224 }
225}
226
227impl RateLimiterTrait for RateLimiter {
228 fn rate_limit(&self) -> RateLimit {
230 self.rate_limit()
231 }
232
233 fn check(&self, quota: u64) -> Check {
241 self.check(quota)
242 }
243}
244
245pub struct MonitoredRateLimiter {
247 inner: RateLimiter,
248 metric: LabelGuardedUintGauge<1>,
249 rate_limit: AtomicU64,
250}
251
252impl Deref for MonitoredRateLimiter {
253 type Target = RateLimiter;
254
255 fn deref(&self) -> &Self::Target {
256 &self.inner
257 }
258}
259
260impl RateLimiterTrait for MonitoredRateLimiter {
261 fn rate_limit(&self) -> RateLimit {
262 self.inner.rate_limit()
263 }
264
265 fn check(&self, quota: u64) -> Check {
266 let check = self.inner.check(quota);
267 if matches! { check, Check::Ok} {
268 self.report();
269 }
270 check
271 }
272}
273
274impl MonitoredRateLimiter {
275 fn report(&self) {
279 let rate_limit = self.inner.rate_limit().to_u64();
280 if rate_limit != self.rate_limit.load(Ordering::Relaxed) {
281 self.rate_limit.store(rate_limit, Ordering::Relaxed);
282 self.metric.set(rate_limit);
283 }
284 }
285}
286
287#[derive(Debug)]
288pub struct InfiniteRatelimiter;
289
290impl RateLimiterTrait for InfiniteRatelimiter {
291 fn rate_limit(&self) -> RateLimit {
292 RateLimit::Disabled
293 }
294
295 fn check(&self, _: u64) -> Check {
296 Check::Ok
297 }
298}
299
300#[derive(Debug)]
301pub struct PausedRateLimiter {
302 waiters: Mutex<Vec<oneshot::Sender<()>>>,
303}
304
305impl Default for PausedRateLimiter {
306 fn default() -> Self {
307 Self {
308 waiters: Mutex::new(vec![]),
309 }
310 }
311}
312
313impl Drop for PausedRateLimiter {
314 fn drop(&mut self) {
315 for tx in self.waiters.lock().drain(..) {
316 let _ = tx.send(());
317 }
318 }
319}
320
321impl RateLimiterTrait for PausedRateLimiter {
322 fn rate_limit(&self) -> RateLimit {
323 RateLimit::Pause
324 }
325
326 fn check(&self, _: u64) -> Check {
327 let (tx, rx) = oneshot::channel();
328 self.waiters.lock().push(tx);
329 Check::RetryAfter(rx)
330 }
331}
332
333#[derive(Debug)]
334pub struct FixedRateLimiter {
335 inner: LeakBucket,
336 rate: NonZeroU64,
337}
338
339impl FixedRateLimiter {
340 pub fn new(rate: NonZeroU64) -> Self {
341 let inner = LeakBucket::new(rate);
342 Self { inner, rate }
343 }
344}
345
346impl RateLimiterTrait for FixedRateLimiter {
347 fn rate_limit(&self) -> RateLimit {
348 RateLimit::Fixed(self.rate)
349 }
350
351 fn check(&self, quota: u64) -> Check {
352 match self.inner.check(quota) {
353 Ok(()) => Check::Ok,
354 Err(duration) => Check::Retry(duration),
355 }
356 }
357}
358
/// Lock-free rate accounting on a nanosecond timeline that starts at `origin`.
#[derive(Debug)]
pub struct LeakBucket {
    /// Nanoseconds charged per unit of quota (`NANO / rate`, at least 1).
    scale: AtomicU64,

    /// Position on the timeline (nanoseconds since `origin`) stored by the
    /// most recent admitted check.
    ltat: AtomicU64,

    /// Zero point of the timeline.
    origin: Instant,

    /// Sum of all admitted quotas.
    total_allowed_quotas: AtomicU64,
    /// Wait statistic: grows by the requested wait on every rejection and
    /// shrinks by the unused slack on every admission.
    total_waited_nanos: AtomicI64,
}

impl LeakBucket {
    /// Nanoseconds per second.
    const NANO: u64 = Duration::from_secs(1).as_nanos() as u64;

    /// Nanoseconds charged per unit of quota at `rate`, never less than 1.
    fn scale(rate: NonZeroU64) -> u64 {
        std::cmp::max(Self::NANO / rate.get(), 1)
    }

    /// Clamps an unsigned nanosecond count into the signed statistics counter
    /// instead of letting an `as` cast wrap it to a negative number.
    fn stat_nanos(nanos: u64) -> i64 {
        i64::try_from(nanos).unwrap_or(i64::MAX)
    }

    /// Creates a bucket for `rate` whose timeline starts now.
    fn new(rate: NonZeroU64) -> Self {
        let scale = Self::scale(rate);

        let origin = Instant::now();
        let scale = AtomicU64::new(scale);

        Self {
            scale,
            ltat: AtomicU64::new(0),
            origin,
            total_allowed_quotas: AtomicU64::new(0),
            total_waited_nanos: AtomicI64::new(0),
        }
    }

    /// Tries to admit `quota`.
    ///
    /// Returns `Ok(())` when admitted, otherwise `Err(wait)` with the time
    /// after which the caller should try again. A rejected check does not
    /// modify `ltat`.
    fn check(&self, quota: u64) -> Result<(), Duration> {
        let now = Instant::now();
        // Saturate rather than truncate the 128-bit nanosecond count.
        let tnow = u64::try_from(now.duration_since(self.origin).as_nanos()).unwrap_or(u64::MAX);

        // Saturating: a huge quota must be rejected with a huge wait, not wrap
        // around to a small weight (release) or panic (debug).
        let weight = quota.saturating_mul(self.scale.load(Ordering::Relaxed));

        let mut ltat = self.ltat.load(Ordering::Acquire);
        let tat = loop {
            let tat = ltat.saturating_add(weight);

            if tat > tnow {
                let wait = tat - tnow;
                self.total_waited_nanos
                    .fetch_add(Self::stat_nanos(wait), Ordering::Relaxed);
                return Err(Duration::from_nanos(wait));
            }

            let ltat_new = std::cmp::max(tat, tnow);

            // Retry with the freshly observed value if another thread won.
            match self
                .ltat
                .compare_exchange(ltat, ltat_new, Ordering::Release, Ordering::Acquire)
            {
                Ok(_) => break tat,
                Err(cur) => ltat = cur,
            }
        };

        self.total_allowed_quotas
            .fetch_add(quota, Ordering::Relaxed);
        // `tat <= tnow` holds here, so the subtraction cannot underflow.
        self.total_waited_nanos
            .fetch_sub(Self::stat_nanos(tnow - tat), Ordering::Relaxed);

        Ok(())
    }

    /// Wait statistic divided by the admitted quotas; `0` when nothing has
    /// been admitted yet.
    fn _avg_wait_nanos_per_quota(&self) -> i64 {
        let quotas = self.total_allowed_quotas.load(Ordering::Relaxed);
        if quotas == 0 {
            0
        } else {
            let nanos = self.total_waited_nanos.load(Ordering::Relaxed);
            // Clamp so the divisor is always positive.
            nanos / i64::try_from(quotas).unwrap_or(i64::MAX)
        }
    }

    /// Resets both statistics counters to zero.
    fn _reset_stats(&self) {
        self.total_allowed_quotas.store(0, Ordering::Relaxed);
        self.total_waited_nanos.store(0, Ordering::Relaxed);
    }

    /// Switches the bucket to `rate` by replacing the per-quota scale.
    fn _update(&self, rate: NonZeroU64) {
        let scale = Self::scale(rate);
        self.scale.store(scale, Ordering::Relaxed);
    }
}
473
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    use rand::{Rng, rng as thread_rng};

    use super::*;

    /// Maximum tolerated relative error of the admitted total.
    const ERATIO: f64 = 0.05;
    /// Number of competing threads in the stress tests.
    const THREADS: usize = 8;
    /// Rate given to the bucket under test.
    const RATE: u64 = 1000;
    /// Wall-clock time each stress test runs for.
    const DURATION: Duration = Duration::from_secs(10);

    /// Runs `THREADS` threads against one shared bucket for `DURATION` and
    /// returns the total quota they got admitted.
    ///
    /// Each thread uses a single quota size drawn once from `quota_range`.
    fn run_contended(quota_range: Range<u64>) -> u64 {
        let admitted = Arc::new(AtomicU64::new(0));
        let bucket = Arc::new(LeakBucket::new(RATE.try_into().unwrap()));
        let mut rng = thread_rng();

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let quota = rng.random_range(quota_range.clone());
                let admitted = admitted.clone();
                let bucket = bucket.clone();
                std::thread::spawn(move || {
                    let start = Instant::now();
                    loop {
                        if start.elapsed() >= DURATION {
                            break;
                        }
                        while let Err(dur) = bucket.check(quota) {
                            std::thread::sleep(dur);
                        }
                        // Do not count an admission that landed after the deadline.
                        if start.elapsed() >= DURATION {
                            break;
                        }

                        admitted.fetch_add(quota, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        admitted.load(Ordering::Relaxed)
    }

    /// Total quota a perfectly paced bucket would admit in `DURATION`.
    fn expected_total() -> u64 {
        RATE * DURATION.as_secs()
    }

    /// Relative deviation of `got` from the ideal `RATE * DURATION`.
    fn error_ratio(got: u64) -> f64 {
        got.abs_diff(expected_total()) as f64 / (RATE as f64 * DURATION.as_secs_f64())
    }

    // Ignored by default: the test runs for `DURATION` of wall-clock time.
    #[ignore]
    #[test]
    fn test_leak_bucket() {
        let eratio = error_ratio(run_contended(10..20));
        assert!(eratio < ERATIO, "eratio: {}, target: {}", eratio, ERATIO);
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    // Ignored by default: the test runs for `DURATION` of wall-clock time.
    // Here every single quota is around one second's worth of `RATE`.
    #[ignore]
    #[test]
    fn test_leak_bucket_overflow() {
        let got = run_contended(500..1500);
        let expected = expected_total();
        let eratio = error_ratio(got);
        assert!(
            eratio < ERATIO,
            "eratio: {}, target: {}, got: {}, expected: {}",
            eratio,
            ERATIO,
            got,
            expected
        );
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let limiter = RateLimiter::new(RateLimit::Pause);

        let resume = async {
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert_eq!(limiter.update(RateLimit::Disabled), RateLimit::Pause);
        };

        // Futures are lazy: `wait` has to be polled while the limiter is still
        // paused, otherwise the wake-up on resume is never exercised. `join!`
        // drives both futures concurrently, and the timeout turns a missed
        // wake-up into a failure instead of a hang.
        let both = async {
            tokio::join!(limiter.wait(1), resume);
        };
        tokio::time::timeout(Duration::from_millis(1000), both)
            .await
            .expect("waiter was not woken after the limit was lifted");
    }
}