risingwave_common/util/
recursive.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Track recursion depth and grow the stack when necessary to enable fearless recursion.
16
17use std::cell::RefCell;
18
// Arguments handed to `stacker::maybe_grow` by `Tracker::recurse`; refer to the
// `stacker` documentation for their precise meaning.
// TODO: determine good values or make them configurable
const RED_ZONE: usize = 128 << 10; // 128KiB
const STACK_SIZE: usize = RED_ZONE * 16; // 2MiB
23
/// Recursion-depth bookkeeping owned by a [`Tracker`].
struct Depth {
    /// Number of nested calls currently active; 0 when no recursion is running.
    current: usize,
    /// Deepest depth reached by frames that have already returned, i.e. the max
    /// depth seen so far without counting the frames that are still active.
    last_max: usize,
}

impl Depth {
    /// The initial state: nothing active and no maximum recorded yet.
    const fn new() -> Self {
        Depth {
            last_max: 0,
            current: 0,
        }
    }

    /// Clear both counters, returning to the initial state.
    fn reset(&mut self) {
        self.current = 0;
        self.last_max = 0;
    }
}
44
45/// The tracker for a recursive function.
46pub struct Tracker {
47    depth: RefCell<Depth>,
48}
49
50impl Tracker {
51    /// Create a new tracker.
52    pub const fn new() -> Self {
53        Self {
54            depth: RefCell::new(Depth::new()),
55        }
56    }
57
58    /// Retrieve the current depth of the recursion. Starts from 1 once the
59    /// recursive function is called.
60    pub fn depth(&self) -> usize {
61        self.depth.borrow().current
62    }
63
64    /// Check if the current depth reaches the given depth **for the first time**.
65    ///
66    /// This is useful for logging without any duplication.
67    pub fn depth_reaches(&self, depth: usize) -> bool {
68        let d = self.depth.borrow();
69        d.current == depth && d.current > d.last_max
70    }
71
72    /// Run a recursive function. Grow the stack if necessary.
73    fn recurse<T>(&self, f: impl FnOnce() -> T) -> T {
74        struct DepthGuard<'a> {
75            depth: &'a RefCell<Depth>,
76        }
77
78        impl<'a> DepthGuard<'a> {
79            fn new(depth: &'a RefCell<Depth>) -> Self {
80                depth.borrow_mut().current += 1;
81                Self { depth }
82            }
83        }
84
85        impl Drop for DepthGuard<'_> {
86            fn drop(&mut self) {
87                let mut d = self.depth.borrow_mut();
88                d.last_max = d.last_max.max(d.current); // update the last max depth
89                d.current -= 1; // restore the current depth
90                if d.current == 0 {
91                    d.reset(); // reset state if the recursion is finished
92                }
93            }
94        }
95
96        let _guard = DepthGuard::new(&self.depth);
97
98        if cfg!(madsim) {
99            f() // madsim does not support stack growth
100        } else {
101            stacker::maybe_grow(RED_ZONE, STACK_SIZE, f)
102        }
103    }
104}
105
106/// The extension trait for a thread-local tracker to run a recursive function.
107#[easy_ext::ext(Recurse)]
108impl std::thread::LocalKey<Tracker> {
109    /// Run the given recursive function. Grow the stack if necessary.
110    ///
111    /// # Fearless Recursion
112    ///
113    /// This enables fearless recursion in most cases as long as a single frame
114    /// does not exceed the [`RED_ZONE`] size. That is, the caller can recurse
115    /// as much as it wants without worrying about stack overflow.
116    ///
117    /// # Tracker
118    ///
119    /// The caller can retrieve the [`Tracker`] of the current recursion from
120    /// the closure argument. This can be useful for checking the depth of the
121    /// recursion, logging or throwing an error gracefully if it's too deep.
122    ///
123    /// Note that different trackers defined in different functions are
124    /// independent of each other. If there's a cross-function recursion, the
125    /// tracker retrieved from the closure argument only represents the current
126    /// function's state.
127    ///
128    /// # Example
129    ///
130    /// Define the tracker with [`tracker!`] and call this method on it to run
131    /// a recursive function.
132    ///
133    /// ```ignore
134    /// #[inline(never)]
135    /// fn sum(x: u64) -> u64 {
136    ///     tracker!().recurse(|t| {
137    ///         if t.depth() % 100000 == 0 {
138    ///            eprintln!("too deep!");
139    ///         }
140    ///         if x == 0 {
141    ///             return 0;
142    ///         }
143    ///         x + sum(x - 1)
144    ///     })
145    /// }
146    /// ```
147    pub fn recurse<T>(&'static self, f: impl FnOnce(&Tracker) -> T) -> T {
148        self.with(|t| t.recurse(|| f(t)))
149    }
150}
151
/// Define the tracker for recursion and return it.
///
/// Each invocation site declares its own thread-local [`Tracker`], so trackers
/// from different call sites are independent of each other.
///
/// Call [`Recurse::recurse`] on it to run a recursive function. See
/// documentation there for usage.
#[macro_export]
macro_rules! __recursive_tracker {
    () => {{
        use $crate::util::recursive::Tracker;
        // Absolute path so the expansion does not depend on what `std` resolves
        // to at the call site.
        ::std::thread_local! {
            static __TRACKER: Tracker = const { Tracker::new() };
        }
        __TRACKER
    }};
}
pub use __recursive_tracker as tracker;
167
#[cfg(all(test, not(madsim)))]
mod tests {
    use super::*;

    /// Recurse roughly 1.9 million levels deep through a tracked function and
    /// check both the computed sum and the depth observed at the base case.
    #[test]
    fn test_fearless_recursion() {
        const X: u64 = 1919810;
        // Closed form of 1 + 2 + ... + X (== 1842836177955).
        const EXPECTED: u64 = X * (X + 1) / 2;

        // Prevent inlining so every level keeps its own stack frame.
        #[inline(never)]
        fn sum(x: u64) -> u64 {
            tracker!().recurse(|t| match x {
                0 => {
                    // The outermost call is depth 1, so the base case sits at X + 1.
                    assert_eq!(t.depth(), X as usize + 1);
                    0
                }
                _ => x + sum(x - 1),
            })
        }

        assert_eq!(sum(X), EXPECTED);
    }
}