risingwave_common_rate_limit/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::num::NonZeroU64;
17use std::ops::Deref;
18use std::pin::Pin;
19use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
20use std::sync::{Arc, LazyLock};
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use arc_swap::ArcSwap;
25use parking_lot::Mutex;
26use pin_project_lite::pin_project;
27use risingwave_common::catalog::TableId;
28use risingwave_common::metrics::LabelGuardedUintGaugeVec;
29use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
30use risingwave_common_metrics::{
31    LabelGuardedUintGauge, register_guarded_uint_gauge_vec_with_registry,
32};
33use tokio::sync::oneshot;
34use tokio::time::Sleep;
35
36static METRICS: LazyLock<LabelGuardedUintGaugeVec<1>> = LazyLock::new(|| {
37    register_guarded_uint_gauge_vec_with_registry!(
38        "backfill_rate_limit_bytes",
39        "backfill rate limit bytes per second",
40        &["table_id"],
41        &GLOBAL_METRICS_REGISTRY
42    )
43    .unwrap()
44});
45
46pin_project! {
47    #[derive(Debug)]
48    #[project = DelayProj]
49    pub enum Delay {
50        Noop,
51        Sleep{#[pin] sleep: Sleep},
52        Wait{#[pin] rx: oneshot::Receiver<()> },
53        Infinite,
54    }
55}
56
57impl Delay {
58    pub fn new(duration: Duration) -> Self {
59        match duration {
60            Duration::ZERO => Self::Noop,
61            Duration::MAX => Self::Infinite,
62            dur => Self::Sleep {
63                sleep: tokio::time::sleep(dur),
64            },
65        }
66    }
67}
68
69impl From<Duration> for Delay {
70    fn from(value: Duration) -> Self {
71        Self::new(value)
72    }
73}
74
75impl Future for Delay {
76    type Output = ();
77
78    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79        match self.project() {
80            DelayProj::Noop => Poll::Ready(()),
81            DelayProj::Sleep { sleep } => sleep.poll(cx),
82            DelayProj::Wait { rx } => rx.poll(cx).map(|_| ()),
83            DelayProj::Infinite => Poll::Pending,
84        }
85    }
86}
87
/// Rate limit policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// Rate limit disabled.
    Disabled,
    /// Rate limit with fixed rate.
    Fixed(NonZeroU64),
    /// Pause with 0 rate.
    Pause,
}

impl RateLimit {
    /// Returns `true` if the rate limit is set to the pause policy.
    pub fn is_paused(&self) -> bool {
        matches!(self, Self::Pause)
    }

    /// Encodes the policy as a `u64`.
    ///
    /// The encoding is `u64::MAX` for [`RateLimit::Disabled`], `0` for [`RateLimit::Pause`], and
    /// the rate itself for [`RateLimit::Fixed`]. Note that `Fixed(u64::MAX)` encodes the same as
    /// `Disabled`.
    pub fn to_u64(self) -> u64 {
        self.into()
    }
}

impl From<RateLimit> for u64 {
    fn from(rate_limit: RateLimit) -> Self {
        match rate_limit {
            RateLimit::Disabled => u64::MAX,
            RateLimit::Fixed(rate) => rate.get(),
            RateLimit::Pause => 0,
        }
    }
}

// Adapt to the old rate limit policy, where `None` disables rate limiting and `Some(0)` pauses.
impl From<Option<u32>> for RateLimit {
    fn from(value: Option<u32>) -> Self {
        match value.map(u64::from) {
            None => Self::Disabled,
            // `NonZeroU64::new` returns `None` only for 0, which the old policy used for "pause".
            Some(rate) => NonZeroU64::new(rate).map_or(Self::Pause, Self::Fixed),
        }
    }
}
130
131#[derive(Debug)]
132pub enum Check {
133    Ok,
134    Retry(Duration),
135    RetryAfter(oneshot::Receiver<()>),
136}
137
138impl Check {
139    pub fn is_ok(&self) -> bool {
140        matches!(self, Self::Ok)
141    }
142}
143
144/// Shared behavior for rate limiters.
145pub trait RateLimiterTrait: Send + Sync + 'static {
146    /// Return current throttle policy.
147    fn rate_limit(&self) -> RateLimit;
148
149    /// Check if the request with the given quota is supposed to be allowed at the moment.
150    ///
151    /// On success, the quota will be consumed. [`Check::Ok`] is returned.
152    /// The caller is supposed to proceed the request with the given quota.
153    ///
154    /// On failure, [`Check::Retry`] or [`Check::RetryAfter`] is returned.
155    /// The caller is supposed to retry the check after the given duration or retry after receiving the signal.
156    fn check(&self, quota: u64) -> Check;
157}
158
159/// A rate limiter that supports multiple rate limit policy and online policy switch.
160pub struct RateLimiter {
161    inner: ArcSwap<Box<dyn RateLimiterTrait>>,
162}
163
164impl RateLimiter {
165    fn new_inner(rate_limit: RateLimit) -> Box<dyn RateLimiterTrait> {
166        match rate_limit {
167            RateLimit::Disabled => Box::new(InfiniteRatelimiter),
168            RateLimit::Fixed(rate) => Box::new(FixedRateLimiter::new(rate)),
169            RateLimit::Pause => Box::new(PausedRateLimiter::default()),
170        }
171    }
172
173    /// Create a new rate limiter with given rate limit policy.
174    pub fn new(rate_limit: RateLimit) -> Self {
175        let inner: Box<dyn RateLimiterTrait> = Self::new_inner(rate_limit);
176        let inner = ArcSwap::new(Arc::new(inner));
177        Self { inner }
178    }
179
180    /// Update rate limit policy of the rate limiter.
181    ///
182    /// Returns the old rate limit policy.
183    pub fn update(&self, rate_limit: RateLimit) -> RateLimit {
184        let old = self.rate_limit();
185        if self.rate_limit() == rate_limit {
186            return old;
187        }
188        let inner = Self::new_inner(rate_limit);
189        self.inner.store(Arc::new(inner));
190        old
191    }
192
193    /// Monitor the rate limiter with related table id.
194    pub fn monitored(self, table_id: impl Into<TableId>) -> MonitoredRateLimiter {
195        let metric = METRICS.with_guarded_label_values(&[&table_id.into().to_string()]);
196        let rate_limit = AtomicU64::new(self.rate_limit().to_u64());
197        MonitoredRateLimiter {
198            inner: self,
199            metric,
200            rate_limit,
201        }
202    }
203
204    pub fn rate_limit(&self) -> RateLimit {
205        self.inner.load().rate_limit()
206    }
207
208    pub fn check(&self, quota: u64) -> Check {
209        self.inner.load().check(quota)
210    }
211
212    pub async fn wait(&self, quota: u64) {
213        loop {
214            match self.check(quota) {
215                Check::Ok => return,
216                Check::Retry(duration) => {
217                    tokio::time::sleep(duration).await;
218                }
219                Check::RetryAfter(rx) => {
220                    let _ = rx.await;
221                }
222            }
223        }
224    }
225}
226
227impl RateLimiterTrait for RateLimiter {
228    /// Return current throttle policy.
229    fn rate_limit(&self) -> RateLimit {
230        self.rate_limit()
231    }
232
233    /// Check if the request with the given quota is supposed to be allowed at the moment.
234    ///
235    /// On success, the quota will be consumed. [`Check::Ok`] is returned.
236    /// The caller is supposed to proceed the request with the given quota.
237    ///
238    /// On failure, [`Check::Retry`] or [`Check::RetryAfter`] is returned.
239    /// The caller is supposed to retry the check after the given duration or retry after receiving the signal.
240    fn check(&self, quota: u64) -> Check {
241        self.check(quota)
242    }
243}
244
245/// A rate limiter that supports multiple rate limit policy, online policy switch and metrics support.
246pub struct MonitoredRateLimiter {
247    inner: RateLimiter,
248    metric: LabelGuardedUintGauge<1>,
249    rate_limit: AtomicU64,
250}
251
252impl Deref for MonitoredRateLimiter {
253    type Target = RateLimiter;
254
255    fn deref(&self) -> &Self::Target {
256        &self.inner
257    }
258}
259
260impl RateLimiterTrait for MonitoredRateLimiter {
261    fn rate_limit(&self) -> RateLimit {
262        self.inner.rate_limit()
263    }
264
265    fn check(&self, quota: u64) -> Check {
266        let check = self.inner.check(quota);
267        if matches! { check, Check::Ok} {
268            self.report();
269        }
270        check
271    }
272}
273
274impl MonitoredRateLimiter {
275    /// Report the rate limit policy to the metric if updated.
276    ///
277    /// `report` is called automatically by each `until` call.
278    fn report(&self) {
279        let rate_limit = self.inner.rate_limit().to_u64();
280        if rate_limit != self.rate_limit.load(Ordering::Relaxed) {
281            self.rate_limit.store(rate_limit, Ordering::Relaxed);
282            self.metric.set(rate_limit);
283        }
284    }
285}
286
287#[derive(Debug)]
288pub struct InfiniteRatelimiter;
289
290impl RateLimiterTrait for InfiniteRatelimiter {
291    fn rate_limit(&self) -> RateLimit {
292        RateLimit::Disabled
293    }
294
295    fn check(&self, _: u64) -> Check {
296        Check::Ok
297    }
298}
299
300#[derive(Debug)]
301pub struct PausedRateLimiter {
302    waiters: Mutex<Vec<oneshot::Sender<()>>>,
303}
304
305impl Default for PausedRateLimiter {
306    fn default() -> Self {
307        Self {
308            waiters: Mutex::new(vec![]),
309        }
310    }
311}
312
313impl Drop for PausedRateLimiter {
314    fn drop(&mut self) {
315        for tx in self.waiters.lock().drain(..) {
316            let _ = tx.send(());
317        }
318    }
319}
320
321impl RateLimiterTrait for PausedRateLimiter {
322    fn rate_limit(&self) -> RateLimit {
323        RateLimit::Pause
324    }
325
326    fn check(&self, _: u64) -> Check {
327        let (tx, rx) = oneshot::channel();
328        self.waiters.lock().push(tx);
329        Check::RetryAfter(rx)
330    }
331}
332
333#[derive(Debug)]
334pub struct FixedRateLimiter {
335    inner: LeakBucket,
336    rate: NonZeroU64,
337}
338
339impl FixedRateLimiter {
340    pub fn new(rate: NonZeroU64) -> Self {
341        let inner = LeakBucket::new(rate);
342        Self { inner, rate }
343    }
344}
345
346impl RateLimiterTrait for FixedRateLimiter {
347    fn rate_limit(&self) -> RateLimit {
348        RateLimit::Fixed(self.rate)
349    }
350
351    fn check(&self, quota: u64) -> Check {
352        match self.inner.check(quota) {
353            Ok(()) => Check::Ok,
354            Err(duration) => Check::Retry(duration),
355        }
356    }
357}
358
/// A GCRA-like leaky bucket virtual scheduler. It never denies a request outright, even when the
/// request's weight is larger than tau; it only tracks the TAT.
///
/// A request of `quota` units is admitted once `quota * scale` nanoseconds have elapsed since the
/// previously admitted request (or since creation, for the first request). Otherwise the caller
/// is told how long to wait. On admission the bucket restarts from the current time, so idle time
/// is not banked and there is no burst allowance.
#[derive(Debug)]
pub struct LeakBucket {
    /// Weight per 1 unit of quota, in nanoseconds.
    ///
    /// Always non-zero: `scale = max(1s / rate, 1ns)`.
    scale: AtomicU64,

    /// Last admitted request's TAT (Theoretical Arrival Time), in nanoseconds since `origin`.
    ltat: AtomicU64,

    /// Zero time instant.
    origin: Instant,

    /// Total allowed quotas.
    total_allowed_quotas: AtomicU64,
    /// Total waited nanos: retry waits are added, the slack of admitted requests is subtracted.
    total_waited_nanos: AtomicI64,
}

impl LeakBucket {
    const NANO: u64 = Duration::from_secs(1).as_nanos() as u64;

    /// Calculates the weight per 1 unit of quota in nanoseconds (at least 1).
    fn scale(rate: NonZeroU64) -> u64 {
        std::cmp::max(Self::NANO / rate.get(), 1)
    }

    /// Create a new GCRA-like leaky bucket virtual scheduler with the given rate (units/second).
    fn new(rate: NonZeroU64) -> Self {
        let scale = Self::scale(rate);

        let origin = Instant::now();
        let scale = AtomicU64::new(scale);

        Self {
            scale,
            ltat: AtomicU64::new(0),
            origin,
            total_allowed_quotas: AtomicU64::new(0),
            total_waited_nanos: AtomicI64::new(0),
        }
    }

    /// Check if the request with the given quota is supposed to be allowed at the moment.
    ///
    /// On success, the quota will be consumed. The caller is supposed to proceed with the quota.
    ///
    /// On failure, the minimal duration to wait before retrying `check()` is returned.
    fn check(&self, quota: u64) -> Result<(), Duration> {
        let now = Instant::now();
        let tnow = now.duration_since(self.origin).as_nanos() as u64;

        // Saturate instead of overflowing. Plain arithmetic panics in debug builds and wraps in
        // release builds (possibly admitting a huge request at once). A weight of u64::MAX ns
        // (~584 years) effectively means "never".
        let weight = quota.saturating_mul(self.scale.load(Ordering::Relaxed));

        let mut ltat = self.ltat.load(Ordering::Acquire);
        let tat = loop {
            let tat = ltat.saturating_add(weight);

            if tat > tnow {
                let wait = tat - tnow;
                self.total_waited_nanos.fetch_add(
                    i64::try_from(wait).unwrap_or(i64::MAX),
                    Ordering::Relaxed,
                );
                return Err(Duration::from_nanos(wait));
            }

            // Here `tat <= tnow`, so the new TAT is simply `tnow`: the bucket restarts from the
            // current time and idle time is not banked.
            match self
                .ltat
                .compare_exchange(ltat, tnow, Ordering::Release, Ordering::Acquire)
            {
                Ok(_) => break tat,
                // Another request was admitted concurrently; recompute against its TAT.
                Err(cur) => ltat = cur,
            }
        };

        self.total_allowed_quotas
            .fetch_add(quota, Ordering::Relaxed);
        self.total_waited_nanos
            .fetch_sub((tnow - tat) as i64, Ordering::Relaxed);

        Ok(())
    }

    // TODO(MrCroxx): Reserved for adaptive rate limiter.
    /// Average wait time per quota.
    ///
    /// Positive value indicates waits, negative value indicates there is spare rate limit.
    fn _avg_wait_nanos_per_quota(&self) -> i64 {
        let quotas = self.total_allowed_quotas.load(Ordering::Relaxed);
        if quotas == 0 {
            0
        } else {
            let nanos = self.total_waited_nanos.load(Ordering::Relaxed);
            nanos / quotas as i64
        }
    }

    // TODO(MrCroxx): Reserved for adaptive rate limiter.
    /// Reset statistics.
    fn _reset_stats(&self) {
        self.total_allowed_quotas.store(0, Ordering::Relaxed);
        self.total_waited_nanos.store(0, Ordering::Relaxed);
    }

    // TODO(MrCroxx): Reserved for adaptive rate limiter.
    /// Update rate limit with the given rate.
    fn _update(&self, rate: NonZeroU64) {
        let scale = Self::scale(rate);
        self.scale.store(scale, Ordering::Relaxed);
    }
}
473
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    use rand::{Rng, rng as thread_rng};

    use super::*;

    const ERATIO: f64 = 0.05;
    const THREADS: usize = 8;
    const RATE: u64 = 1000;
    const DURATION: Duration = Duration::from_secs(10);

    /// Drives one shared [`LeakBucket`] limited to `RATE` from `THREADS` threads for `DURATION`.
    ///
    /// Each thread draws one quota from `quota_range` and then acquires it repeatedly. Only quota
    /// acquired before the deadline is counted. Returns the total acquired quota.
    fn drive_leak_bucket(quota_range: std::ops::Range<u64>) -> u64 {
        let acquired = Arc::new(AtomicU64::new(0));
        let bucket = Arc::new(LeakBucket::new(RATE.try_into().unwrap()));
        let mut rng = thread_rng();
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let quota = rng.random_range(quota_range.clone());
                let acquired = acquired.clone();
                let bucket = bucket.clone();
                std::thread::spawn(move || {
                    let start = Instant::now();
                    while start.elapsed() < DURATION {
                        while let Err(dur) = bucket.check(quota) {
                            std::thread::sleep(dur);
                        }
                        if start.elapsed() >= DURATION {
                            break;
                        }
                        acquired.fetch_add(quota, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        acquired.load(Ordering::Relaxed)
    }

    /// Asserts that `got` is within `ERATIO` of the quota that `RATE` allows over `DURATION`.
    fn assert_close_to_rate(got: u64) {
        let expected = RATE * DURATION.as_secs();
        let eratio = got.abs_diff(expected) as f64 / (RATE as f64 * DURATION.as_secs_f64());
        assert!(
            eratio < ERATIO,
            "eratio: {}, target: {}, got: {}, expected: {}",
            eratio,
            ERATIO,
            got,
            expected
        );
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    /// To run this test:
    ///
    /// ```bash
    /// cargo test --package risingwave_common_rate_limit --lib -- tests::test_leak_bucket --exact --show-output --ignored
    /// ```
    #[ignore]
    #[test]
    fn test_leak_bucket() {
        assert_close_to_rate(drive_leak_bucket(10..20));
    }

    /// To run this test:
    ///
    /// ```bash
    /// cargo test --package risingwave_common_rate_limit --lib -- tests::test_leak_bucket_overflow --exact --show-output --ignored
    /// ```
    #[ignore]
    #[test]
    fn test_leak_bucket_overflow() {
        assert_close_to_rate(drive_leak_bucket(500..1500));
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let l = Arc::new(RateLimiter::new(RateLimit::Pause));

        // Futures are lazy: the wait must be polled while the limiter is still paused, otherwise
        // its first check would only run after the resume and the pause path would go untested.
        let mut delay = std::pin::pin!(l.wait(1));
        assert!(
            tokio::time::timeout(Duration::from_millis(100), &mut delay)
                .await
                .is_err(),
            "wait must park while the limiter is paused"
        );

        l.update(RateLimit::Disabled);

        tokio::time::timeout(Duration::from_secs(1), delay)
            .await
            .expect("wait must be released once the limiter is resumed");
    }
}