risingwave_common_log/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::num::{NonZeroU32, NonZeroUsize};
16use std::sync::atomic::{AtomicUsize, Ordering};
17
18use governor::Quota;
19
20type RateLimiter = governor::RateLimiter<
21    governor::state::NotKeyed,
22    governor::state::InMemoryState,
23    governor::clock::MonotonicClock,
24>;
25
26/// `LogSuppressor` is a helper to suppress log spamming.
27pub struct LogSuppressor {
28    /// The number of times the log has been suppressed. Will be returned and cleared when the
29    /// rate limiter allows next log to be printed.
30    suppressed_count: AtomicUsize,
31
32    /// Inner rate limiter.
33    rate_limiter: RateLimiter,
34}
35
/// Error returned by [`LogSuppressor::check`] when the current log should be dropped.
///
/// The drop is counted internally and reported by the next `check` call that is allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LogSuppressed;
38
39impl LogSuppressor {
40    pub fn new(rate_limiter: RateLimiter) -> Self {
41        Self {
42            suppressed_count: AtomicUsize::new(0),
43            rate_limiter,
44        }
45    }
46
47    /// Create a `LogSuppressor` that allows `per_second` logs per second.
48    pub fn per_second(per_second: u32) -> Self {
49        Self::new(RateLimiter::direct(Quota::per_second(
50            NonZeroU32::new(per_second).unwrap(),
51        )))
52    }
53
54    /// Create a `LogSuppressor` that allows `per_minute` logs per minute.
55    pub fn per_minute(per_minute: u32) -> Self {
56        Self::new(RateLimiter::direct(Quota::per_minute(
57            NonZeroU32::new(per_minute).unwrap(),
58        )))
59    }
60
61    /// Check if the log should be suppressed.
62    /// If the log should be suppressed, return `Err(LogSuppressed)`.
63    /// Otherwise, return `Ok(Some(..))` with count of suppressed messages since last check,
64    /// or `Ok(None)` if there's none.
65    pub fn check(&self) -> core::result::Result<Option<NonZeroUsize>, LogSuppressed> {
66        match self.rate_limiter.check() {
67            Ok(()) => Ok(NonZeroUsize::new(
68                self.suppressed_count.swap(0, Ordering::Relaxed),
69            )),
70            Err(_) => {
71                self.suppressed_count.fetch_add(1, Ordering::Relaxed);
72                Err(LogSuppressed)
73            }
74        }
75    }
76}
77
78impl Default for LogSuppressor {
79    /// Default rate limiter allows 1 log per second.
80    fn default() -> Self {
81        Self::per_second(1)
82    }
83}
84
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use std::time::{Duration, Instant};

    use tracing_subscriber::util::SubscriberInitExt;

    use super::*;

    /// Drives a 5-logs-per-second suppressor with one attempt every 10ms for 1000 ticks and
    /// reports the observed rate.
    ///
    /// It also checks that every attempt is accounted for exactly once: the attempt was
    /// allowed, was reported as suppressed by a later allowed call, or is still pending in
    /// the counter at the end.
    #[tokio::test]
    async fn demo() {
        let _logger = tracing_subscriber::fmt::Subscriber::builder()
            .with_max_level(tracing::Level::ERROR)
            .set_default();

        // Declared outside the loop so the pending counter can be inspected afterwards.
        static RATE_LIMITER: LazyLock<LogSuppressor> =
            LazyLock::new(|| LogSuppressor::per_second(5));

        const ATTEMPTS: usize = 1000;

        let mut interval = tokio::time::interval(Duration::from_millis(10));

        let mut allowed: usize = 0;
        let mut suppressed: usize = 0;

        let start = Instant::now();

        for _ in 0..ATTEMPTS {
            interval.tick().await;

            if let Ok(suppressed_count) = RATE_LIMITER.check() {
                suppressed += suppressed_count.map(|v| v.get()).unwrap_or_default();
                allowed += 1;
                tracing::error!(suppressed_count, "failed to foo bar");
            }
        }
        let duration = start.elapsed();

        tracing::error!(
            allowed,
            suppressed,
            ?duration,
            rate = allowed as f64 / duration.as_secs_f64()
        );

        // Suppressions after the last allowed log have not been reported through `check` yet.
        let pending = RATE_LIMITER
            .suppressed_count
            .load(std::sync::atomic::Ordering::Relaxed);
        assert!(allowed >= 1, "expected at least one log to be allowed");
        assert_eq!(allowed + suppressed + pending, ATTEMPTS);
    }
}