risingwave_common_rate_limit/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::num::NonZeroU64;
17use std::ops::Deref;
18use std::pin::Pin;
19use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
20use std::sync::{Arc, LazyLock};
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use arc_swap::ArcSwap;
25use parking_lot::Mutex;
26use pin_project_lite::pin_project;
27use risingwave_common::array::DataChunk;
28use risingwave_common::catalog::TableId;
29use risingwave_common::metrics::LabelGuardedUintGaugeVec;
30use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
31use risingwave_common_metrics::{
32    LabelGuardedUintGauge, register_guarded_uint_gauge_vec_with_registry,
33};
34use tokio::sync::oneshot;
35use tokio::time::Sleep;
36
37static METRICS: LazyLock<LabelGuardedUintGaugeVec> = LazyLock::new(|| {
38    register_guarded_uint_gauge_vec_with_registry!(
39        "backfill_rate_limit_bytes",
40        "backfill rate limit bytes per second",
41        &["table_id"],
42        &GLOBAL_METRICS_REGISTRY
43    )
44    .unwrap()
45});
46
47pin_project! {
48    #[derive(Debug)]
49    #[project = DelayProj]
50    pub enum Delay {
51        Noop,
52        Sleep{#[pin] sleep: Sleep},
53        Wait{#[pin] rx: oneshot::Receiver<()> },
54        Infinite,
55    }
56}
57
58impl Delay {
59    pub fn new(duration: Duration) -> Self {
60        match duration {
61            Duration::ZERO => Self::Noop,
62            Duration::MAX => Self::Infinite,
63            dur => Self::Sleep {
64                sleep: tokio::time::sleep(dur),
65            },
66        }
67    }
68}
69
70impl From<Duration> for Delay {
71    fn from(value: Duration) -> Self {
72        Self::new(value)
73    }
74}
75
76impl Future for Delay {
77    type Output = ();
78
79    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
80        match self.project() {
81            DelayProj::Noop => Poll::Ready(()),
82            DelayProj::Sleep { sleep } => sleep.poll(cx),
83            DelayProj::Wait { rx } => rx.poll(cx).map(|_| ()),
84            DelayProj::Infinite => Poll::Pending,
85        }
86    }
87}
88
/// Rate limit policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// Rate limit disabled.
    Disabled,
    /// Rate limit with fixed rate.
    Fixed(NonZeroU64),
    /// Pause with 0 rate.
    Pause,
}

impl RateLimit {
    /// Return if the rate limit is set to pause policy.
    pub fn is_paused(&self) -> bool {
        matches!(self, Self::Pause)
    }

    /// Convert the policy into its numeric form, see `From<RateLimit> for u64`.
    pub fn to_u64(self) -> u64 {
        self.into()
    }
}

impl From<RateLimit> for u64 {
    /// Numeric form of the policy: `u64::MAX` when disabled, the rate itself when fixed,
    /// and `0` when paused.
    fn from(rate_limit: RateLimit) -> Self {
        match rate_limit {
            RateLimit::Disabled => u64::MAX,
            RateLimit::Fixed(rate) => rate.get(),
            RateLimit::Pause => 0,
        }
    }
}

// Adapt to the old rate limit policy.
impl From<Option<u32>> for RateLimit {
    /// Convert the old `Option<u32>` representation: `None` disables the limit, `Some(0)`
    /// pauses, and any other value is a fixed rate.
    fn from(value: Option<u32>) -> Self {
        let Some(rate) = value else {
            return Self::Disabled;
        };
        // `NonZeroU64::new` returns `None` exactly for zero, which is the pause policy, so
        // no `unsafe` construction is needed.
        NonZeroU64::new(u64::from(rate)).map_or(Self::Pause, Self::Fixed)
    }
}
131
132#[derive(Debug)]
133pub enum Check {
134    Ok,
135    Retry(Duration),
136    RetryAfter(oneshot::Receiver<()>),
137}
138
139impl Check {
140    pub fn is_ok(&self) -> bool {
141        matches!(self, Self::Ok)
142    }
143}
144
145/// Shared behavior for rate limiters.
146pub trait RateLimiterTrait: Send + Sync + 'static {
147    /// Return current throttle policy.
148    fn rate_limit(&self) -> RateLimit;
149
150    /// Check if the request with the given quota is supposed to be allowed at the moment.
151    ///
152    /// On success, the quota will be consumed. [`Check::Ok`] is returned.
153    /// The caller is supposed to proceed the request with the given quota.
154    ///
155    /// On failure, [`Check::Retry`] or [`Check::RetryAfter`] is returned.
156    /// The caller is supposed to retry the check after the given duration or retry after receiving the signal.
157    fn check(&self, quota: u64) -> Check;
158}
159
160/// A rate limiter that supports multiple rate limit policy and online policy switch.
161pub struct RateLimiter {
162    inner: ArcSwap<Box<dyn RateLimiterTrait>>,
163}
164
165impl RateLimiter {
166    fn new_inner(rate_limit: RateLimit) -> Box<dyn RateLimiterTrait> {
167        match rate_limit {
168            RateLimit::Disabled => Box::new(InfiniteRatelimiter),
169            RateLimit::Fixed(rate) => Box::new(FixedRateLimiter::new(rate)),
170            RateLimit::Pause => Box::new(PausedRateLimiter::default()),
171        }
172    }
173
174    /// Create a new rate limiter with given rate limit policy.
175    pub fn new(rate_limit: RateLimit) -> Self {
176        let inner: Box<dyn RateLimiterTrait> = Self::new_inner(rate_limit);
177        let inner = ArcSwap::new(Arc::new(inner));
178        Self { inner }
179    }
180
181    /// Update rate limit policy of the rate limiter.
182    ///
183    /// Returns the old rate limit policy.
184    pub fn update(&self, rate_limit: RateLimit) -> RateLimit {
185        let old = self.rate_limit();
186        if self.rate_limit() == rate_limit {
187            return old;
188        }
189        let inner = Self::new_inner(rate_limit);
190        self.inner.store(Arc::new(inner));
191        old
192    }
193
194    /// Monitor the rate limiter with related table id.
195    pub fn monitored(self, table_id: impl Into<TableId>) -> MonitoredRateLimiter {
196        let metric = METRICS.with_guarded_label_values(&[&table_id.into().to_string()]);
197        let rate_limit = AtomicU64::new(self.rate_limit().to_u64());
198        MonitoredRateLimiter {
199            inner: self,
200            metric,
201            rate_limit,
202        }
203    }
204
205    pub fn rate_limit(&self) -> RateLimit {
206        self.inner.load().rate_limit()
207    }
208
209    pub fn check(&self, quota: u64) -> Check {
210        self.inner.load().check(quota)
211    }
212
213    pub async fn wait(&self, quota: u64) {
214        loop {
215            match self.check(quota) {
216                Check::Ok => return,
217                Check::Retry(duration) => {
218                    tokio::time::sleep(duration).await;
219                }
220                Check::RetryAfter(rx) => {
221                    let _ = rx.await;
222                }
223            }
224        }
225    }
226
227    pub async fn wait_chunk(&self, chunk: &DataChunk) {
228        self.wait(chunk.rate_limit_permits()).await
229    }
230}
231
232impl RateLimiterTrait for RateLimiter {
233    /// Return current throttle policy.
234    fn rate_limit(&self) -> RateLimit {
235        self.rate_limit()
236    }
237
238    /// Check if the request with the given quota is supposed to be allowed at the moment.
239    ///
240    /// On success, the quota will be consumed. [`Check::Ok`] is returned.
241    /// The caller is supposed to proceed the request with the given quota.
242    ///
243    /// On failure, [`Check::Retry`] or [`Check::RetryAfter`] is returned.
244    /// The caller is supposed to retry the check after the given duration or retry after receiving the signal.
245    fn check(&self, quota: u64) -> Check {
246        self.check(quota)
247    }
248}
249
250/// A rate limiter that supports multiple rate limit policy, online policy switch and metrics support.
251pub struct MonitoredRateLimiter {
252    inner: RateLimiter,
253    metric: LabelGuardedUintGauge,
254    rate_limit: AtomicU64,
255}
256
257impl Deref for MonitoredRateLimiter {
258    type Target = RateLimiter;
259
260    fn deref(&self) -> &Self::Target {
261        &self.inner
262    }
263}
264
265impl RateLimiterTrait for MonitoredRateLimiter {
266    fn rate_limit(&self) -> RateLimit {
267        self.inner.rate_limit()
268    }
269
270    fn check(&self, quota: u64) -> Check {
271        let check = self.inner.check(quota);
272        if matches! { check, Check::Ok} {
273            self.report();
274        }
275        check
276    }
277}
278
279impl MonitoredRateLimiter {
280    /// Report the rate limit policy to the metric if updated.
281    ///
282    /// `report` is called automatically by each `until` call.
283    fn report(&self) {
284        let rate_limit = self.inner.rate_limit().to_u64();
285        if rate_limit != self.rate_limit.load(Ordering::Relaxed) {
286            self.rate_limit.store(rate_limit, Ordering::Relaxed);
287            self.metric.set(rate_limit);
288        }
289    }
290}
291
292#[derive(Debug)]
293pub struct InfiniteRatelimiter;
294
295impl RateLimiterTrait for InfiniteRatelimiter {
296    fn rate_limit(&self) -> RateLimit {
297        RateLimit::Disabled
298    }
299
300    fn check(&self, _: u64) -> Check {
301        Check::Ok
302    }
303}
304
305#[derive(Debug)]
306pub struct PausedRateLimiter {
307    waiters: Mutex<Vec<oneshot::Sender<()>>>,
308}
309
310impl Default for PausedRateLimiter {
311    fn default() -> Self {
312        Self {
313            waiters: Mutex::new(vec![]),
314        }
315    }
316}
317
318impl Drop for PausedRateLimiter {
319    fn drop(&mut self) {
320        for tx in self.waiters.lock().drain(..) {
321            let _ = tx.send(());
322        }
323    }
324}
325
326impl RateLimiterTrait for PausedRateLimiter {
327    fn rate_limit(&self) -> RateLimit {
328        RateLimit::Pause
329    }
330
331    fn check(&self, _: u64) -> Check {
332        let (tx, rx) = oneshot::channel();
333        self.waiters.lock().push(tx);
334        Check::RetryAfter(rx)
335    }
336}
337
338#[derive(Debug)]
339pub struct FixedRateLimiter {
340    inner: LeakBucket,
341    rate: NonZeroU64,
342}
343
344impl FixedRateLimiter {
345    pub fn new(rate: NonZeroU64) -> Self {
346        let inner = LeakBucket::new(rate);
347        Self { inner, rate }
348    }
349}
350
351impl RateLimiterTrait for FixedRateLimiter {
352    fn rate_limit(&self) -> RateLimit {
353        RateLimit::Fixed(self.rate)
354    }
355
356    fn check(&self, quota: u64) -> Check {
357        match self.inner.check(quota) {
358            Ok(()) => Check::Ok,
359            Err(duration) => Check::Retry(duration),
360        }
361    }
362}
363
/// A GCRA-like leaky bucket virtual scheduler that never denies a request, even one whose
/// weight is larger than tau, and that only tracks the TAT.
#[derive(Debug)]
pub struct LeakBucket {
    /// Weight of 1.0 unit of quota, in nanoseconds.
    ///
    /// scale is always non-zero.
    ///
    /// scale = 1 second (in nanoseconds) / rate, raised to 1 if the division yields 0.
    scale: AtomicU64,

    /// Last request's TAT (Theoretical Arrival Time) in nanosecond.
    ltat: AtomicU64,

    /// Zero time instant.
    origin: Instant,

    /// Total allowed quotas.
    total_allowed_quotas: AtomicU64,
    /// Total waited nanos.
    total_waited_nanos: AtomicI64,
}

impl LeakBucket {
    /// Nanoseconds per second.
    const NANO: u64 = Duration::from_secs(1).as_nanos() as u64;

    /// Calculate the weight of 1.0 unit of quota in nanoseconds for the given rate.
    ///
    /// The result is never zero: rates above one unit per nanosecond are clamped to 1.
    fn scale(rate: NonZeroU64) -> u64 {
        std::cmp::max(Self::NANO / rate.get(), 1)
    }

    /// Convert nanoseconds to `i64` for the statistics, saturating at `i64::MAX` instead
    /// of wrapping to a negative value.
    fn stat_nanos(nanos: u64) -> i64 {
        i64::try_from(nanos).unwrap_or(i64::MAX)
    }

    /// Create a new GCRA-like leaky bucket virtual scheduler with given rate.
    ///
    /// The zero time instant is the moment of creation.
    fn new(rate: NonZeroU64) -> Self {
        let scale = Self::scale(rate);

        let origin = Instant::now();
        let scale = AtomicU64::new(scale);

        Self {
            scale,
            ltat: AtomicU64::new(0),
            origin,
            total_allowed_quotas: AtomicU64::new(0),
            total_waited_nanos: AtomicI64::new(0),
        }
    }

    /// Check if the request with the given quota is supposed to be allowed at the moment.
    ///
    /// On success, the quota will be consumed. The caller is supposed to proceed the quota.
    ///
    /// On failure, the minimal duration to retry `check()` is returned.
    fn check(&self, quota: u64) -> Result<(), Duration> {
        let now = Instant::now();
        let tnow = now.duration_since(self.origin).as_nanos() as u64;

        // Saturate rather than overflow: a huge quota must yield a (very long) retry
        // duration, not a panic in debug builds or a wrapped, tiny weight in release builds.
        let weight = quota.saturating_mul(self.scale.load(Ordering::Relaxed));

        let mut ltat = self.ltat.load(Ordering::Acquire);
        let tat = loop {
            let tat = ltat.saturating_add(weight);

            if tat > tnow {
                let wait = tat - tnow;
                self.total_waited_nanos
                    .fetch_add(Self::stat_nanos(wait), Ordering::Relaxed);
                return Err(Duration::from_nanos(wait));
            }

            // `tat <= tnow` holds here, so this evaluates to `tnow`.
            let ltat_new = std::cmp::max(tat, tnow);

            match self
                .ltat
                .compare_exchange(ltat, ltat_new, Ordering::Release, Ordering::Acquire)
            {
                Ok(_) => break tat,
                // Another caller moved the TAT first; retry with the value it stored.
                Err(cur) => ltat = cur,
            }
        };

        self.total_allowed_quotas
            .fetch_add(quota, Ordering::Relaxed);
        // Time by which the request arrived after its TAT counts as spare rate limit.
        self.total_waited_nanos
            .fetch_sub(Self::stat_nanos(tnow - tat), Ordering::Relaxed);

        Ok(())
    }

    // TODO(MrCroxx): Reserved for adaptive rate limiter.
    /// Average wait time per quota.
    ///
    /// Positive value indicates waits, negative value indicates there is spare rate limit.
    fn _avg_wait_nanos_per_quota(&self) -> i64 {
        let quotas = self.total_allowed_quotas.load(Ordering::Relaxed);
        if quotas == 0 {
            0
        } else {
            let nanos = self.total_waited_nanos.load(Ordering::Relaxed);
            nanos / quotas as i64
        }
    }

    // TODO(MrCroxx): Reserved for adaptive rate limiter.
    /// Reset statistics.
    fn _reset_stats(&self) {
        self.total_allowed_quotas.store(0, Ordering::Relaxed);
        self.total_waited_nanos.store(0, Ordering::Relaxed);
    }

    // TODO(MrCroxx): Reserved for adaptive rate limiter.
    /// Update rate limit with the given rate.
    fn _update(&self, rate: NonZeroU64) {
        let scale = Self::scale(rate);
        self.scale.store(scale, Ordering::Relaxed);
    }
}
478
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    use rand::{Rng, rng as thread_rng};

    use super::*;

    /// Maximum tolerated relative error between granted and expected quota.
    const ERATIO: f64 = 0.05;
    /// Number of worker threads sharing one bucket.
    const THREADS: usize = 8;
    /// Rate of the bucket under test, in quota per second.
    const RATE: u64 = 1000;
    /// How long each worker keeps requesting quota.
    const DURATION: Duration = Duration::from_secs(10);

    /// Run `THREADS` workers against one shared [`LeakBucket`] of rate `RATE` for
    /// `DURATION` and return the total quota they were granted.
    ///
    /// Each worker draws one quota from `quotas` up front and requests that same quota
    /// repeatedly, sleeping for the suggested duration whenever the bucket asks to retry.
    fn run_workers(quotas: Range<u64>) -> u64 {
        let granted = Arc::new(AtomicU64::new(0));
        let bucket = Arc::new(LeakBucket::new(RATE.try_into().unwrap()));
        let mut rng = thread_rng();

        // Collect eagerly so that all workers run concurrently before any join.
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let quota = rng.random_range(quotas.clone());
                let granted = granted.clone();
                let bucket = bucket.clone();
                std::thread::spawn(move || {
                    let start = Instant::now();
                    while start.elapsed() < DURATION {
                        while let Err(dur) = bucket.check(quota) {
                            std::thread::sleep(dur);
                        }
                        // Quota granted after the deadline is not counted.
                        if start.elapsed() >= DURATION {
                            break;
                        }

                        granted.fetch_add(quota, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        granted.load(Ordering::Relaxed)
    }

    /// Assert that `got` is within `ERATIO` of the quota expected for `RATE` over `DURATION`.
    fn assert_within_tolerance(got: u64) {
        let expected = RATE * DURATION.as_secs();
        let error = got.abs_diff(expected);
        let eratio = error as f64 / expected as f64;
        assert!(
            eratio < ERATIO,
            "eratio: {}, target: {}, got: {}, expected: {}",
            eratio,
            ERATIO,
            got,
            expected
        );
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    /// To run this test:
    ///
    /// ```bash
    /// cargo test --package risingwave_common_rate_limit --lib -- tests::test_leak_bucket --exact --show-output --ignored
    /// ```
    #[ignore]
    #[test]
    fn test_leak_bucket() {
        assert_within_tolerance(run_workers(10..20));
    }

    /// Same as `test_leak_bucket`, but with per-request quotas around the whole
    /// per-second rate.
    ///
    /// To run this test:
    ///
    /// ```bash
    /// cargo test --package risingwave_common_rate_limit --lib -- tests::test_leak_bucket_overflow --exact --show-output --ignored
    /// ```
    #[ignore]
    #[test]
    fn test_leak_bucket_overflow() {
        assert_within_tolerance(run_workers(500..1500));
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let l = Arc::new(RateLimiter::new(RateLimit::Pause));

        // Futures are lazy: spawn the waiter so that it is polled, and therefore
        // registered with the paused limiter, before the policy changes.
        let waiter = tokio::spawn({
            let l = l.clone();
            async move { l.wait(1).await }
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(
            !waiter.is_finished(),
            "waiter must stay blocked while the limiter is paused"
        );

        l.update(RateLimit::Disabled);

        tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .expect("waiter must be woken up after the limiter is resumed")
            .expect("waiter task must not panic");
    }
}