risingwave_common_rate_limit/
lib.rsuse std::future::Future;
use std::num::NonZeroU64;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, LazyLock};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use risingwave_common::catalog::TableId;
use risingwave_common::metrics::LabelGuardedUintGaugeVec;
use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
use risingwave_common_metrics::{
register_guarded_uint_gauge_vec_with_registry, LabelGuardedUintGauge,
};
use tokio::sync::oneshot;
use tokio::time::Sleep;
/// Gauge family exporting the backfill rate limit ("bytes per second"), labelled by `table_id`.
///
/// Registered in [`GLOBAL_METRICS_REGISTRY`] lazily, on first access.
static METRICS: LazyLock<LabelGuardedUintGaugeVec<1>> = LazyLock::new(|| {
    let registered = register_guarded_uint_gauge_vec_with_registry!(
        "backfill_rate_limit_bytes",
        "backfill rate limit bytes per second",
        &["table_id"],
        &GLOBAL_METRICS_REGISTRY
    );
    // NOTE(review): presumably registration only fails on a duplicate or malformed metric
    // definition; it is treated as fatal here.
    registered.unwrap()
});
pin_project! {
/// A future that completes after a rate-limit delay.
///
/// `Delay::new` (also reachable through `From<Duration>`) maps `Duration::ZERO` to `Noop`,
/// `Duration::MAX` to `Infinite`, and any other duration to `Sleep`.
#[derive(Debug)]
#[project = DelayProj]
pub enum Delay {
// Ready on the first poll.
Noop,
// Ready when the wrapped tokio timer fires.
Sleep{#[pin] sleep: Sleep},
// Ready once the oneshot sender sends or is dropped (the receive result is discarded).
Wait{#[pin] rx: oneshot::Receiver<()> },
// Never ready: polling always returns `Poll::Pending` and registers no waker.
Infinite,
}
}
impl Delay {
    /// Builds the delay future that corresponds to `duration`.
    ///
    /// A zero duration yields [`Delay::Noop`] (ready immediately), `Duration::MAX` yields
    /// [`Delay::Infinite`] (never ready), and any other value becomes a tokio sleep timer.
    pub fn new(duration: Duration) -> Self {
        if duration == Duration::ZERO {
            Self::Noop
        } else if duration == Duration::MAX {
            Self::Infinite
        } else {
            Self::Sleep {
                sleep: tokio::time::sleep(duration),
            }
        }
    }
}
impl From<Duration> for Delay {
fn from(value: Duration) -> Self {
Self::new(value)
}
}
impl Future for Delay {
    type Output = ();

    /// Polls the underlying delay.
    ///
    /// `Noop` is always ready and `Infinite` is never ready (no waker is registered). A `Wait`
    /// finishes whether the sender delivered a value or was dropped.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match projected {
            DelayProj::Noop => Poll::Ready(()),
            DelayProj::Infinite => Poll::Pending,
            DelayProj::Sleep { sleep } => sleep.poll(cx),
            DelayProj::Wait { rx } => match rx.poll(cx) {
                // `Err(RecvError)` (sender dropped) ends the wait just like a sent value.
                Poll::Ready(_) => Poll::Ready(()),
                Poll::Pending => Poll::Pending,
            },
        }
    }
}
/// Rate limit configuration enforced by a [`RateLimiter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// No limit: every request is admitted immediately.
    Disabled,
    /// Admit at most this many quota units per second.
    Fixed(NonZeroU64),
    /// Admit nothing until the limit is changed.
    Pause,
}

impl RateLimit {
    /// Returns `true` if this limit pauses all traffic.
    pub fn is_paused(&self) -> bool {
        matches!(self, Self::Pause)
    }

    /// Numeric form of the limit: `u64::MAX` when disabled, `0` when paused, otherwise the rate.
    pub fn to_u64(self) -> u64 {
        self.into()
    }
}

impl From<RateLimit> for u64 {
    fn from(rate_limit: RateLimit) -> Self {
        match rate_limit {
            RateLimit::Disabled => u64::MAX,
            RateLimit::Fixed(rate) => rate.get(),
            RateLimit::Pause => 0,
        }
    }
}

impl From<Option<u32>> for RateLimit {
    /// `None` disables limiting, `Some(0)` pauses, and any other value becomes a fixed rate.
    fn from(value: Option<u32>) -> Self {
        match value {
            None => Self::Disabled,
            // `NonZeroU64::new` yields `None` exactly for 0, so zero maps to `Pause` without
            // needing `unsafe { NonZeroU64::new_unchecked(..) }`.
            Some(rate) => NonZeroU64::new(u64::from(rate)).map_or(Self::Pause, Self::Fixed),
        }
    }
}
/// Outcome of asking a rate limiter to admit some quota.
#[derive(Debug)]
pub enum Check {
    /// The quota is admitted; proceed now.
    Ok,
    /// Not admitted yet; check again after this duration.
    Retry(Duration),
    /// Not admitted; check again once this receiver resolves (value sent or sender dropped).
    RetryAfter(oneshot::Receiver<()>),
}

impl Check {
    /// Returns `true` only for [`Check::Ok`].
    pub fn is_ok(&self) -> bool {
        match self {
            Check::Ok => true,
            Check::Retry(_) | Check::RetryAfter(_) => false,
        }
    }
}
/// Common interface of the concrete rate limiter implementations.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored as a
/// `Box<dyn RateLimiterTrait>` inside a shared [`RateLimiter`].
pub trait RateLimiterTrait
where
    Self: Send + Sync + 'static,
{
    /// The limit this limiter currently enforces.
    fn rate_limit(&self) -> RateLimit;

    /// Asks to admit `quota` units, returning whether to proceed or how to retry.
    fn check(&self, quota: u64) -> Check;
}
/// A rate limiter whose enforcement policy can be replaced at runtime.
///
/// The active policy is a boxed [`RateLimiterTrait`] held in an [`ArcSwap`].
/// [`RateLimiter::update`] swaps in a freshly built implementation whenever the configured
/// [`RateLimit`] changes.
pub struct RateLimiter {
    inner: ArcSwap<Box<dyn RateLimiterTrait>>,
}

impl RateLimiter {
    /// Builds the concrete limiter implementing `rate_limit`.
    fn new_inner(rate_limit: RateLimit) -> Box<dyn RateLimiterTrait> {
        match rate_limit {
            RateLimit::Disabled => Box::new(InfiniteRatelimiter),
            RateLimit::Fixed(rate) => Box::new(FixedRateLimiter::new(rate)),
            RateLimit::Pause => Box::new(PausedRateLimiter::default()),
        }
    }

    /// Creates a limiter enforcing `rate_limit`.
    pub fn new(rate_limit: RateLimit) -> Self {
        let inner = ArcSwap::new(Arc::new(Self::new_inner(rate_limit)));
        Self { inner }
    }

    /// Switches to `rate_limit` and returns the previously configured limit.
    ///
    /// If the limit is unchanged, nothing is replaced, so e.g. an existing fixed-rate limiter
    /// keeps its pacing state. Otherwise a new implementation is stored. When the last reference
    /// to a replaced paused limiter is released, its waiters are signalled to re-check.
    pub fn update(&self, rate_limit: RateLimit) -> RateLimit {
        // Load once: comparing against a second load could observe a concurrent update and make
        // the decision below inconsistent with the `old` value that is returned.
        let old = self.rate_limit();
        if old == rate_limit {
            return old;
        }
        self.inner.store(Arc::new(Self::new_inner(rate_limit)));
        old
    }

    /// Wraps this limiter so its limit is exported through the `backfill_rate_limit_bytes`
    /// gauge, labelled with `table_id`.
    pub fn monitored(self, table_id: impl Into<TableId>) -> MonitoredRateLimiter {
        let metric = METRICS.with_guarded_label_values(&[&table_id.into().to_string()]);
        let current = self.rate_limit().to_u64();
        // Publish the initial limit now. `MonitoredRateLimiter` only writes the gauge when the
        // limit differs from its cached value, and the cache starts at `current`. Without this
        // write, the gauge would stay unset until the limit first changed.
        metric.set(current);
        MonitoredRateLimiter {
            inner: self,
            metric,
            rate_limit: AtomicU64::new(current),
        }
    }

    /// The limit currently being enforced.
    pub fn rate_limit(&self) -> RateLimit {
        self.inner.load().rate_limit()
    }

    /// Asks the current implementation to admit `quota`, without waiting.
    pub fn check(&self, quota: u64) -> Check {
        self.inner.load().check(quota)
    }

    /// Waits until `quota` is admitted, sleeping or awaiting a resume signal between checks.
    pub async fn wait(&self, quota: u64) {
        loop {
            match self.check(quota) {
                Check::Ok => return,
                Check::Retry(duration) => tokio::time::sleep(duration).await,
                Check::RetryAfter(rx) => {
                    // A dropped sender also means "re-check", so the receive error is ignored.
                    let _ = rx.await;
                }
            }
        }
    }
}

impl RateLimiterTrait for RateLimiter {
    fn rate_limit(&self) -> RateLimit {
        // Method resolution picks the inherent method, so this does not recurse.
        self.rate_limit()
    }

    fn check(&self, quota: u64) -> Check {
        // Inherent method, as above.
        self.check(quota)
    }
}
pub struct MonitoredRateLimiter {
inner: RateLimiter,
metric: LabelGuardedUintGauge<1>,
rate_limit: AtomicU64,
}
impl Deref for MonitoredRateLimiter {
type Target = RateLimiter;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl RateLimiterTrait for MonitoredRateLimiter {
fn rate_limit(&self) -> RateLimit {
self.inner.rate_limit()
}
fn check(&self, quota: u64) -> Check {
let check = self.inner.check(quota);
if matches! { check, Check::Ok} {
self.report();
}
check
}
}
impl MonitoredRateLimiter {
fn report(&self) {
let rate_limit = self.inner.rate_limit().to_u64();
if rate_limit != self.rate_limit.load(Ordering::Relaxed) {
self.rate_limit.store(rate_limit, Ordering::Relaxed);
self.metric.set(rate_limit);
}
}
}
/// Limiter used for [`RateLimit::Disabled`]: every request is admitted immediately.
#[derive(Debug)]
pub struct InfiniteRatelimiter;

impl RateLimiterTrait for InfiniteRatelimiter {
    fn rate_limit(&self) -> RateLimit {
        RateLimit::Disabled
    }

    fn check(&self, _quota: u64) -> Check {
        Check::Ok
    }
}
/// Limiter used for [`RateLimit::Pause`]: admits nothing.
///
/// Each `check` hands out a oneshot receiver. All outstanding receivers are signalled when this
/// limiter is dropped, which happens after it has been replaced by [`RateLimiter::update`] and
/// the last reference is released. That prompts callers to re-check against the new policy.
#[derive(Debug, Default)]
pub struct PausedRateLimiter {
    waiters: Mutex<Vec<oneshot::Sender<()>>>,
}

impl Drop for PausedRateLimiter {
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so the lock does not need to be taken.
        for tx in self.waiters.get_mut().drain(..) {
            // The receiver may already be gone; there is nothing to wake in that case.
            let _ = tx.send(());
        }
    }
}

impl RateLimiterTrait for PausedRateLimiter {
    fn rate_limit(&self) -> RateLimit {
        RateLimit::Pause
    }

    fn check(&self, _quota: u64) -> Check {
        let (tx, rx) = oneshot::channel();
        let mut waiters = self.waiters.lock();
        // Drop senders whose receivers are gone, e.g. cancelled waits or callers that discarded
        // the returned `Check`. Otherwise every check while paused would grow this list without
        // bound until the pause ends.
        waiters.retain(|waiter| !waiter.is_closed());
        waiters.push(tx);
        Check::RetryAfter(rx)
    }
}
/// Limiter used for [`RateLimit::Fixed`]: paces requests to `rate` quota units per second with
/// a [`LeakBucket`].
#[derive(Debug)]
pub struct FixedRateLimiter {
    inner: LeakBucket,
    rate: NonZeroU64,
}

impl FixedRateLimiter {
    /// Creates a limiter admitting `rate` quota units per second.
    pub fn new(rate: NonZeroU64) -> Self {
        Self {
            inner: LeakBucket::new(rate),
            rate,
        }
    }
}

impl RateLimiterTrait for FixedRateLimiter {
    fn rate_limit(&self) -> RateLimit {
        RateLimit::Fixed(self.rate)
    }

    fn check(&self, quota: u64) -> Check {
        if let Err(wait) = self.inner.check(quota) {
            Check::Retry(wait)
        } else {
            Check::Ok
        }
    }
}
/// Pacing limiter behind [`FixedRateLimiter`], built on atomics with a compare-and-swap loop.
///
/// Every quota unit costs `scale` nanoseconds, where `scale = 1s / rate` (at least 1 ns). A
/// request for `quota` units is admitted only once `quota * scale` nanoseconds have passed
/// since the previously admitted request (or since creation). Admission then moves that
/// reference point to "now". Unused time is not banked, so there is no burst allowance.
#[derive(Debug)]
pub struct LeakBucket {
    // Nanoseconds charged per quota unit.
    scale: AtomicU64,
    // Time of the last admitted request, in nanoseconds since `origin`.
    ltat: AtomicU64,
    // Reference instant for the nanosecond timestamps above.
    origin: Instant,
    // Statistics: total quota admitted.
    total_allowed_quotas: AtomicU64,
    // Statistics: waits reported to rejected callers, minus slack observed on admitted ones.
    total_waited_nanos: AtomicI64,
}

impl LeakBucket {
    /// Nanoseconds per second.
    const NANO: u64 = Duration::from_secs(1).as_nanos() as u64;

    /// Per-unit cost in nanoseconds for `rate` units per second, clamped to at least 1.
    fn scale(rate: NonZeroU64) -> u64 {
        std::cmp::max(Self::NANO / rate.get(), 1)
    }

    /// Creates a bucket pacing at `rate` units per second, with its clock starting now.
    fn new(rate: NonZeroU64) -> Self {
        Self {
            scale: AtomicU64::new(Self::scale(rate)),
            ltat: AtomicU64::new(0),
            origin: Instant::now(),
            total_allowed_quotas: AtomicU64::new(0),
            total_waited_nanos: AtomicI64::new(0),
        }
    }

    /// Tries to admit `quota` units at the current instant.
    ///
    /// Returns `Ok(())` on admission, or `Err(wait)` with the time remaining before the request
    /// could be admitted. All arithmetic saturates, so an oversized `quota` yields a very long
    /// wait instead of overflowing. Overflow would panic in debug builds and wrap into an
    /// instant admission in release builds.
    fn check(&self, quota: u64) -> Result<(), Duration> {
        let elapsed = Instant::now().duration_since(self.origin).as_nanos();
        let tnow = u64::try_from(elapsed).unwrap_or(u64::MAX);
        let weight = quota.saturating_mul(self.scale.load(Ordering::Relaxed));
        let mut ltat = self.ltat.load(Ordering::Acquire);
        let tat = loop {
            let tat = ltat.saturating_add(weight);
            if tat > tnow {
                let wait = tat - tnow;
                self.total_waited_nanos.fetch_add(
                    i64::try_from(wait).unwrap_or(i64::MAX),
                    Ordering::Relaxed,
                );
                return Err(Duration::from_nanos(wait));
            }
            // Here `tat <= tnow`, so `ltat` always advances to `tnow`.
            let ltat_new = std::cmp::max(tat, tnow);
            match self
                .ltat
                .compare_exchange(ltat, ltat_new, Ordering::Release, Ordering::Acquire)
            {
                Ok(_) => break tat,
                // Another caller was admitted concurrently; retry against its timestamp.
                Err(cur) => ltat = cur,
            }
        };
        self.total_allowed_quotas
            .fetch_add(quota, Ordering::Relaxed);
        self.total_waited_nanos.fetch_sub(
            i64::try_from(tnow - tat).unwrap_or(i64::MAX),
            Ordering::Relaxed,
        );
        Ok(())
    }

    /// Average recorded wait per admitted quota unit, or 0 if nothing has been admitted yet.
    fn _avg_wait_nanos_per_quota(&self) -> i64 {
        let quotas = self.total_allowed_quotas.load(Ordering::Relaxed);
        if quotas == 0 {
            0
        } else {
            let nanos = self.total_waited_nanos.load(Ordering::Relaxed);
            nanos / quotas as i64
        }
    }

    /// Clears the statistics counters.
    fn _reset_stats(&self) {
        self.total_allowed_quotas.store(0, Ordering::Relaxed);
        self.total_waited_nanos.store(0, Ordering::Relaxed);
    }

    /// Changes the per-unit cost to match `rate` units per second.
    fn _update(&self, rate: NonZeroU64) {
        let scale = Self::scale(rate);
        self.scale.store(scale, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    use rand::{thread_rng, Rng};

    use super::*;

    const ERATIO: f64 = 0.05;
    const THREADS: usize = 8;
    const RATE: u64 = 1000;
    const DURATION: Duration = Duration::from_secs(10);

    /// Drives one shared `LeakBucket` limited to `RATE` units/s from `THREADS` threads for
    /// `DURATION`.
    ///
    /// Each thread repeatedly requests a fixed quota drawn from `quota_range`, sleeping whenever
    /// the bucket asks it to, and counts the quota it is granted. Asserts that the total granted
    /// is within `ERATIO` of `RATE * DURATION`.
    fn assert_leak_bucket_accuracy(quota_range: Range<u64>) {
        let granted = Arc::new(AtomicU64::new(0));
        let bucket = Arc::new(LeakBucket::new(RATE.try_into().unwrap()));

        let mut rng = thread_rng();
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let quota = rng.gen_range(quota_range.clone());
                let granted = granted.clone();
                let bucket = bucket.clone();
                std::thread::spawn(move || {
                    let start = Instant::now();
                    while start.elapsed() < DURATION {
                        while let Err(wait) = bucket.check(quota) {
                            std::thread::sleep(wait);
                        }
                        // Quota granted after the deadline is not counted.
                        if start.elapsed() >= DURATION {
                            break;
                        }
                        granted.fetch_add(quota, Ordering::Relaxed);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        let got = granted.load(Ordering::Relaxed);
        let expected = RATE * DURATION.as_secs();
        let eratio = got.abs_diff(expected) as f64 / (RATE as f64 * DURATION.as_secs_f64());
        assert!(
            eratio < ERATIO,
            "eratio: {}, target: {}, got: {}, expected: {}",
            eratio,
            ERATIO,
            got,
            expected
        );
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }

    #[ignore]
    #[test]
    fn test_leak_bucket() {
        assert_leak_bucket_accuracy(10..20);
    }

    #[ignore]
    #[test]
    fn test_leak_bucket_overflow() {
        assert_leak_bucket_accuracy(500..1500);
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let limiter = Arc::new(RateLimiter::new(RateLimit::Pause));

        // Start waiting *before* resuming. Futures are lazy, so merely creating `wait(1)` and
        // awaiting it after the resume would never observe the pause.
        let mut waiter = tokio::spawn({
            let limiter = limiter.clone();
            async move { limiter.wait(1).await }
        });

        // While paused, the waiter must stay blocked.
        assert!(
            tokio::time::timeout(Duration::from_millis(100), &mut waiter)
                .await
                .is_err(),
            "wait completed while the limiter was paused"
        );

        // Resuming must wake the blocked waiter promptly.
        limiter.update(RateLimit::Disabled);
        tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .expect("wait was not woken up after resuming")
            .unwrap();
    }
}