risingwave_common_rate_limit/
lib.rs1use std::future::Future;
16use std::num::NonZeroU64;
17use std::ops::Deref;
18use std::pin::Pin;
19use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
20use std::sync::{Arc, LazyLock};
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use arc_swap::ArcSwap;
25use parking_lot::Mutex;
26use pin_project_lite::pin_project;
27use risingwave_common::array::DataChunk;
28use risingwave_common::catalog::TableId;
29use risingwave_common::metrics::LabelGuardedUintGaugeVec;
30use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
31use risingwave_common_metrics::{
32 LabelGuardedUintGauge, register_guarded_uint_gauge_vec_with_registry,
33};
34use tokio::sync::oneshot;
35use tokio::time::Sleep;
36
/// Lazily-initialised gauge family named `backfill_rate_limit_bytes`, labelled by `table_id`.
///
/// The family is registered against `GLOBAL_METRICS_REGISTRY` the first time this static is
/// dereferenced. A registration error panics through the `unwrap()` below.
static METRICS: LazyLock<LabelGuardedUintGaugeVec> = LazyLock::new(|| {
    register_guarded_uint_gauge_vec_with_registry!(
        "backfill_rate_limit_bytes",
        "backfill rate limit bytes per second",
        &["table_id"],
        &GLOBAL_METRICS_REGISTRY
    )
    .unwrap()
});
46
// `Delay` is a future describing how long a caller has to wait. The enum is declared through
// `pin_project!` so that the pinned `sleep` / `rx` fields can be polled via the generated
// `DelayProj` projection type.
pin_project! {
    #[derive(Debug)]
    #[project = DelayProj]
    pub enum Delay {
        // Completes on the first poll.
        Noop,
        // Completes when the wrapped tokio `Sleep` timer completes.
        Sleep{#[pin] sleep: Sleep},
        // Completes when the oneshot receiver resolves (a value was sent or the sender dropped).
        Wait{#[pin] rx: oneshot::Receiver<()> },
        // Never completes: polling always returns `Poll::Pending` and registers no waker.
        Infinite,
    }
}
57
58impl Delay {
59 pub fn new(duration: Duration) -> Self {
60 match duration {
61 Duration::ZERO => Self::Noop,
62 Duration::MAX => Self::Infinite,
63 dur => Self::Sleep {
64 sleep: tokio::time::sleep(dur),
65 },
66 }
67 }
68}
69
70impl From<Duration> for Delay {
71 fn from(value: Duration) -> Self {
72 Self::new(value)
73 }
74}
75
76impl Future for Delay {
77 type Output = ();
78
79 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
80 match self.project() {
81 DelayProj::Noop => Poll::Ready(()),
82 DelayProj::Sleep { sleep } => sleep.poll(cx),
83 DelayProj::Wait { rx } => rx.poll(cx).map(|_| ()),
84 DelayProj::Infinite => Poll::Pending,
85 }
86 }
87}
88
/// The configured rate limit of a limiter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// No limit is applied. Converts to `u64::MAX`.
    Disabled,
    /// A fixed, non-zero rate. Converts to the wrapped value.
    Fixed(NonZeroU64),
    /// Nothing is admitted until the limit is changed. Converts to `0`.
    Pause,
}
99
100impl RateLimit {
101 pub fn is_paused(&self) -> bool {
103 matches! { self, Self::Pause }
104 }
105
106 pub fn to_u64(self) -> u64 {
107 self.into()
108 }
109}
110
111impl From<RateLimit> for u64 {
112 fn from(rate_limit: RateLimit) -> Self {
113 match rate_limit {
114 RateLimit::Disabled => u64::MAX,
115 RateLimit::Fixed(rate) => rate.get(),
116 RateLimit::Pause => 0,
117 }
118 }
119}
120
121impl From<Option<u32>> for RateLimit {
123 fn from(value: Option<u32>) -> Self {
124 match value {
125 None => Self::Disabled,
126 Some(0) => Self::Pause,
127 Some(rate) => Self::Fixed(unsafe { NonZeroU64::new_unchecked(rate as _) }),
128 }
129 }
130}
131
/// Outcome of asking a limiter whether some quota may be consumed now.
#[derive(Debug)]
pub enum Check {
    /// The quota was admitted.
    Ok,
    /// The quota was not admitted; retry after the given duration.
    Retry(Duration),
    /// The quota was not admitted; retry once the receiver resolves.
    RetryAfter(oneshot::Receiver<()>),
}
138
139impl Check {
140 pub fn is_ok(&self) -> bool {
141 matches!(self, Self::Ok)
142 }
143}
144
/// Common interface of every rate limiter implementation in this file.
pub trait RateLimiterTrait: Send + Sync + 'static {
    /// Returns the limit this limiter currently enforces.
    fn rate_limit(&self) -> RateLimit;

    /// Asks to consume `quota` right now, without blocking.
    ///
    /// Returns [`Check::Ok`] when the quota is admitted. Otherwise it returns
    /// [`Check::Retry`] or [`Check::RetryAfter`], which tell the caller what to wait for
    /// before checking again.
    fn check(&self, quota: u64) -> Check;
}
159
/// A rate limiter whose limit can be replaced at runtime through `&self`.
pub struct RateLimiter {
    /// The active implementation. `update` swaps in a new one when the limit changes.
    inner: ArcSwap<Box<dyn RateLimiterTrait>>,
}
164
165impl RateLimiter {
166 fn new_inner(rate_limit: RateLimit) -> Box<dyn RateLimiterTrait> {
167 match rate_limit {
168 RateLimit::Disabled => Box::new(InfiniteRatelimiter),
169 RateLimit::Fixed(rate) => Box::new(FixedRateLimiter::new(rate)),
170 RateLimit::Pause => Box::new(PausedRateLimiter::default()),
171 }
172 }
173
174 pub fn new(rate_limit: RateLimit) -> Self {
176 let inner: Box<dyn RateLimiterTrait> = Self::new_inner(rate_limit);
177 let inner = ArcSwap::new(Arc::new(inner));
178 Self { inner }
179 }
180
181 pub fn update(&self, rate_limit: RateLimit) -> RateLimit {
185 let old = self.rate_limit();
186 if self.rate_limit() == rate_limit {
187 return old;
188 }
189 let inner = Self::new_inner(rate_limit);
190 self.inner.store(Arc::new(inner));
191 old
192 }
193
194 pub fn monitored(self, table_id: impl Into<TableId>) -> MonitoredRateLimiter {
196 let metric = METRICS.with_guarded_label_values(&[&table_id.into().to_string()]);
197 let rate_limit = AtomicU64::new(self.rate_limit().to_u64());
198 MonitoredRateLimiter {
199 inner: self,
200 metric,
201 rate_limit,
202 }
203 }
204
205 pub fn rate_limit(&self) -> RateLimit {
206 self.inner.load().rate_limit()
207 }
208
209 pub fn check(&self, quota: u64) -> Check {
210 self.inner.load().check(quota)
211 }
212
213 pub async fn wait(&self, quota: u64) {
214 loop {
215 match self.check(quota) {
216 Check::Ok => return,
217 Check::Retry(duration) => {
218 tokio::time::sleep(duration).await;
219 }
220 Check::RetryAfter(rx) => {
221 let _ = rx.await;
222 }
223 }
224 }
225 }
226
227 pub async fn wait_chunk(&self, chunk: &DataChunk) {
228 self.wait(chunk.rate_limit_permits()).await
229 }
230}
231
232impl RateLimiterTrait for RateLimiter {
233 fn rate_limit(&self) -> RateLimit {
235 self.rate_limit()
236 }
237
238 fn check(&self, quota: u64) -> Check {
246 self.check(quota)
247 }
248}
249
/// A [`RateLimiter`] paired with a gauge that exports its limit.
///
/// NOTE(review): the gauge is only refreshed by this type's `RateLimiterTrait::check`.
/// Methods reached through `Deref` (for example `RateLimiter::wait`) call the inner
/// limiter directly and do not refresh it — confirm this is intended.
pub struct MonitoredRateLimiter {
    /// The wrapped limiter.
    inner: RateLimiter,
    /// Gauge handle obtained from `METRICS` for this limiter's `table_id` label.
    metric: LabelGuardedUintGauge,
    /// Last limit value seen by `report`, used to skip redundant gauge writes.
    rate_limit: AtomicU64,
}
256
257impl Deref for MonitoredRateLimiter {
258 type Target = RateLimiter;
259
260 fn deref(&self) -> &Self::Target {
261 &self.inner
262 }
263}
264
265impl RateLimiterTrait for MonitoredRateLimiter {
266 fn rate_limit(&self) -> RateLimit {
267 self.inner.rate_limit()
268 }
269
270 fn check(&self, quota: u64) -> Check {
271 let check = self.inner.check(quota);
272 if matches! { check, Check::Ok} {
273 self.report();
274 }
275 check
276 }
277}
278
279impl MonitoredRateLimiter {
280 fn report(&self) {
284 let rate_limit = self.inner.rate_limit().to_u64();
285 if rate_limit != self.rate_limit.load(Ordering::Relaxed) {
286 self.rate_limit.store(rate_limit, Ordering::Relaxed);
287 self.metric.set(rate_limit);
288 }
289 }
290}
291
/// Limiter used for [`RateLimit::Disabled`]: it admits every request.
#[derive(Debug)]
pub struct InfiniteRatelimiter;
294
295impl RateLimiterTrait for InfiniteRatelimiter {
296 fn rate_limit(&self) -> RateLimit {
297 RateLimit::Disabled
298 }
299
300 fn check(&self, _: u64) -> Check {
301 Check::Ok
302 }
303}
304
/// Limiter used for [`RateLimit::Pause`]: it admits nothing.
///
/// Every `check` hands out a oneshot receiver. The matching senders are kept here and are
/// all fired when this limiter is dropped.
#[derive(Debug)]
pub struct PausedRateLimiter {
    /// Senders for the receivers handed out by `check`.
    waiters: Mutex<Vec<oneshot::Sender<()>>>,
}
309
310impl Default for PausedRateLimiter {
311 fn default() -> Self {
312 Self {
313 waiters: Mutex::new(vec![]),
314 }
315 }
316}
317
318impl Drop for PausedRateLimiter {
319 fn drop(&mut self) {
320 for tx in self.waiters.lock().drain(..) {
321 let _ = tx.send(());
322 }
323 }
324}
325
326impl RateLimiterTrait for PausedRateLimiter {
327 fn rate_limit(&self) -> RateLimit {
328 RateLimit::Pause
329 }
330
331 fn check(&self, _: u64) -> Check {
332 let (tx, rx) = oneshot::channel();
333 self.waiters.lock().push(tx);
334 Check::RetryAfter(rx)
335 }
336}
337
/// Limiter used for [`RateLimit::Fixed`], backed by a [`LeakBucket`].
#[derive(Debug)]
pub struct FixedRateLimiter {
    /// The bucket that makes the admit / wait decisions.
    inner: LeakBucket,
    /// The configured rate, kept so `rate_limit` can report it.
    rate: NonZeroU64,
}
343
344impl FixedRateLimiter {
345 pub fn new(rate: NonZeroU64) -> Self {
346 let inner = LeakBucket::new(rate);
347 Self { inner, rate }
348 }
349}
350
351impl RateLimiterTrait for FixedRateLimiter {
352 fn rate_limit(&self) -> RateLimit {
353 RateLimit::Fixed(self.rate)
354 }
355
356 fn check(&self, quota: u64) -> Check {
357 match self.inner.check(quota) {
358 Ok(()) => Check::Ok,
359 Err(duration) => Check::Retry(duration),
360 }
361 }
362}
363
/// Lock-free bucket that decides whether a quota may be admitted at the current time.
#[derive(Debug)]
pub struct LeakBucket {
    /// Nanoseconds charged per unit of quota (`NANO / rate`, at least 1).
    scale: AtomicU64,

    /// Timestamp written by the last successful `check`, in nanoseconds since `origin`.
    ltat: AtomicU64,

    /// Reference instant; all timestamps are measured relative to it.
    origin: Instant,

    /// Statistics: sum of all admitted quotas.
    total_allowed_quotas: AtomicU64,
    /// Statistics: rejected checks add their wait time here, admitted checks subtract the
    /// slack (`tnow - tat`) they had.
    total_waited_nanos: AtomicI64,
}
385
386impl LeakBucket {
387 const NANO: u64 = Duration::from_secs(1).as_nanos() as u64;
388
389 fn scale(rate: NonZeroU64) -> u64 {
391 std::cmp::max(Self::NANO / rate.get(), 1)
392 }
393
394 fn new(rate: NonZeroU64) -> Self {
396 let scale = Self::scale(rate);
397
398 let origin = Instant::now();
399 let scale = AtomicU64::new(scale);
400
401 Self {
402 scale,
403 ltat: AtomicU64::new(0),
404 origin,
405 total_allowed_quotas: AtomicU64::new(0),
406 total_waited_nanos: AtomicI64::new(0),
407 }
408 }
409
410 fn check(&self, quota: u64) -> Result<(), Duration> {
416 let now = Instant::now();
417 let tnow = now.duration_since(self.origin).as_nanos() as u64;
418
419 let weight = quota * self.scale.load(Ordering::Relaxed);
420
421 let mut ltat = self.ltat.load(Ordering::Acquire);
422 let tat = loop {
423 let tat = ltat + weight;
424
425 if tat > tnow {
426 self.total_waited_nanos
427 .fetch_add((tat - tnow) as i64, Ordering::Relaxed);
428 return Err(Duration::from_nanos(tat - tnow));
429 }
430
431 let ltat_new = std::cmp::max(tat, tnow);
432
433 match self
434 .ltat
435 .compare_exchange(ltat, ltat_new, Ordering::Release, Ordering::Acquire)
436 {
437 Ok(_) => break tat,
438 Err(cur) => ltat = cur,
439 }
440 };
441
442 self.total_allowed_quotas
443 .fetch_add(quota, Ordering::Relaxed);
444 self.total_waited_nanos
445 .fetch_sub((tnow - tat) as i64, Ordering::Relaxed);
446
447 Ok(())
448 }
449
450 fn _avg_wait_nanos_per_quota(&self) -> i64 {
455 let quotas = self.total_allowed_quotas.load(Ordering::Relaxed);
456 if quotas == 0 {
457 0
458 } else {
459 let nanos = self.total_waited_nanos.load(Ordering::Relaxed);
460 nanos / quotas as i64
461 }
462 }
463
464 fn _reset_stats(&self) {
467 self.total_allowed_quotas.store(0, Ordering::Relaxed);
468 self.total_waited_nanos.store(0, Ordering::Relaxed);
469 }
470
471 fn _update(&self, rate: NonZeroU64) {
474 let scale = Self::scale(rate);
475 self.scale.store(scale, Ordering::Relaxed);
476 }
477}
478
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    use rand::{Rng, rng as thread_rng};

    use super::*;

    const ERATIO: f64 = 0.05;
    const THREADS: usize = 8;
    const RATE: u64 = 1000;
    const DURATION: Duration = Duration::from_secs(10);

    /// Body of one worker thread: repeatedly acquires `quota` from `bucket` for `DURATION`,
    /// adding every admitted quota to `admitted`.
    fn worker(quota: u64, admitted: Arc<AtomicU64>, bucket: Arc<LeakBucket>) {
        let start = Instant::now();
        loop {
            if start.elapsed() >= DURATION {
                break;
            }
            while let Err(dur) = bucket.check(quota) {
                std::thread::sleep(dur);
            }
            // Quota acquired after the deadline is not counted.
            if start.elapsed() >= DURATION {
                break;
            }

            admitted.fetch_add(quota, Ordering::Relaxed);
        }
    }

    /// Runs `THREADS` workers against one shared bucket of rate `RATE`, each with a quota
    /// drawn at random from `quotas`, and returns the total quota admitted.
    fn drive_leak_bucket(quotas: Range<u64>) -> u64 {
        let v = Arc::new(AtomicU64::new(0));
        let lb = Arc::new(LeakBucket::new(RATE.try_into().unwrap()));

        let mut rng = thread_rng();
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let quota = rng.random_range(quotas.clone());
                let v = v.clone();
                let limiter = lb.clone();
                std::thread::spawn(move || worker(quota, v, limiter))
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        v.load(Ordering::Relaxed)
    }

    /// Relative error of the admitted total against `RATE * DURATION`.
    fn error_ratio(got: u64) -> f64 {
        let expected = RATE * DURATION.as_secs();
        got.abs_diff(expected) as f64 / (RATE as f64 * DURATION.as_secs_f64())
    }

    // Ignored by default; it runs for `DURATION` (10 s) of wall-clock time on real threads.
    #[ignore]
    #[test]
    fn test_leak_bucket() {
        let got = drive_leak_bucket(10..20);
        let eratio = error_ratio(got);
        assert!(eratio < ERATIO, "eratio: {}, target: {}", eratio, ERATIO);
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    // Same as `test_leak_bucket`, but each request is around one second's worth of quota.
    // Ignored by default for the same reason.
    #[ignore]
    #[test]
    fn test_leak_bucket_overflow() {
        let got = drive_leak_bucket(500..1500);
        let expected = RATE * DURATION.as_secs();
        let eratio = error_ratio(got);
        assert!(
            eratio < ERATIO,
            "eratio: {}, target: {}, got: {}, expected: {}",
            eratio,
            ERATIO,
            got,
            expected
        );
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let l = Arc::new(RateLimiter::new(RateLimit::Pause));

        // Futures are lazy: the waiter has to be spawned so that it really parks on the
        // paused limiter *before* the limit is lifted.
        let waiter = tokio::spawn({
            let l = l.clone();
            async move { l.wait(1).await }
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(
            !waiter.is_finished(),
            "waiter must stay blocked while the limiter is paused"
        );

        assert_eq!(l.update(RateLimit::Disabled), RateLimit::Pause);

        tokio::time::timeout(Duration::from_millis(1000), waiter)
            .await
            .expect("waiter should be released once the limiter is resumed")
            .unwrap();
    }
}