risingwave_common/util/recursive.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Track the recursion and grow the stack when necessary to enable fearless recursion.
16
17use std::cell::RefCell;
18
// Tuning knobs handed to `stacker::maybe_grow`; the `stacker` crate documentation
// defines their exact meaning.
// TODO: determine good values or make them configurable
/// Passed to `stacker::maybe_grow` as its `red_zone` argument (128KiB).
const RED_ZONE: usize = 128 << 10;
/// Passed to `stacker::maybe_grow` as its `stack_size` argument (2MiB).
const STACK_SIZE: usize = RED_ZONE * 16;
23
/// Bookkeeping for how deep a tracked recursion currently is.
struct Depth {
    /// Number of recursion frames that are active right now; 0 when no
    /// recursion is in progress.
    current: usize,
    /// The deepest depth recorded so far by frames that have already exited,
    /// i.e. the maximum reached excluding the frame that is running now.
    last_max: usize,
}
31
32impl Depth {
33 const fn new() -> Self {
34 Self {
35 current: 0,
36 last_max: 0,
37 }
38 }
39
40 fn reset(&mut self) {
41 *self = Self::new();
42 }
43}
44
45/// The tracker for a recursive function.
46pub struct Tracker {
47 depth: RefCell<Depth>,
48}
49
50impl Tracker {
51 /// Create a new tracker.
52 pub const fn new() -> Self {
53 Self {
54 depth: RefCell::new(Depth::new()),
55 }
56 }
57
58 /// Retrieve the current depth of the recursion. Starts from 1 once the
59 /// recursive function is called.
60 pub fn depth(&self) -> usize {
61 self.depth.borrow().current
62 }
63
64 /// Check if the current depth reaches the given depth **for the first time**.
65 ///
66 /// This is useful for logging without any duplication.
67 pub fn depth_reaches(&self, depth: usize) -> bool {
68 let d = self.depth.borrow();
69 d.current == depth && d.current > d.last_max
70 }
71
72 /// Run a recursive function. Grow the stack if necessary.
73 fn recurse<T>(&self, f: impl FnOnce() -> T) -> T {
74 struct DepthGuard<'a> {
75 depth: &'a RefCell<Depth>,
76 }
77
78 impl<'a> DepthGuard<'a> {
79 fn new(depth: &'a RefCell<Depth>) -> Self {
80 depth.borrow_mut().current += 1;
81 Self { depth }
82 }
83 }
84
85 impl Drop for DepthGuard<'_> {
86 fn drop(&mut self) {
87 let mut d = self.depth.borrow_mut();
88 d.last_max = d.last_max.max(d.current); // update the last max depth
89 d.current -= 1; // restore the current depth
90 if d.current == 0 {
91 d.reset(); // reset state if the recursion is finished
92 }
93 }
94 }
95
96 let _guard = DepthGuard::new(&self.depth);
97
98 if cfg!(madsim) {
99 f() // madsim does not support stack growth
100 } else {
101 stacker::maybe_grow(RED_ZONE, STACK_SIZE, f)
102 }
103 }
104}
105
106/// The extension trait for a thread-local tracker to run a recursive function.
107#[easy_ext::ext(Recurse)]
108impl std::thread::LocalKey<Tracker> {
109 /// Run the given recursive function. Grow the stack if necessary.
110 ///
111 /// # Fearless Recursion
112 ///
113 /// This enables fearless recursion in most cases as long as a single frame
114 /// does not exceed the [`RED_ZONE`] size. That is, the caller can recurse
115 /// as much as it wants without worrying about stack overflow.
116 ///
117 /// # Tracker
118 ///
119 /// The caller can retrieve the [`Tracker`] of the current recursion from
120 /// the closure argument. This can be useful for checking the depth of the
121 /// recursion, logging or throwing an error gracefully if it's too deep.
122 ///
123 /// Note that different trackers defined in different functions are
124 /// independent of each other. If there's a cross-function recursion, the
125 /// tracker retrieved from the closure argument only represents the current
126 /// function's state.
127 ///
128 /// # Example
129 ///
130 /// Define the tracker with [`tracker!`] and call this method on it to run
131 /// a recursive function.
132 ///
133 /// ```ignore
134 /// #[inline(never)]
135 /// fn sum(x: u64) -> u64 {
136 /// tracker!().recurse(|t| {
137 /// if t.depth() % 100000 == 0 {
138 /// eprintln!("too deep!");
139 /// }
140 /// if x == 0 {
141 /// return 0;
142 /// }
143 /// x + sum(x - 1)
144 /// })
145 /// }
146 /// ```
147 pub fn recurse<T>(&'static self, f: impl FnOnce(&Tracker) -> T) -> T {
148 self.with(|t| t.recurse(|| f(t)))
149 }
150}
151
/// Define the tracker for recursion and return it.
///
/// Call [`Recurse::recurse`] on it to run a recursive function. See
/// documentation there for usage.
///
/// Every expansion site defines its own thread-local tracker, so trackers of
/// different functions are independent of each other.
#[macro_export]
macro_rules! __recursive_tracker {
    () => {{
        use $crate::util::recursive::Tracker;
        // Absolute path: a `std` module or alias at the call site must not be
        // able to hijack the expansion.
        ::std::thread_local! {
            static __TRACKER: Tracker = const { Tracker::new() };
        }
        __TRACKER
    }};
}
pub use __recursive_tracker as tracker;
167
#[cfg(all(test, not(madsim)))]
mod tests {
    use super::*;

    /// Recursing roughly two million levels deep must neither overflow the
    /// stack nor lose track of the depth.
    #[test]
    fn test_fearless_recursion() {
        const X: u64 = 1919810;
        // Closed form of 0 + 1 + ... + X (= 1842836177955).
        const EXPECTED: u64 = X * (X + 1) / 2;

        #[inline(never)]
        fn sum(x: u64) -> u64 {
            tracker!().recurse(|tracker| match x {
                0 => {
                    // Frames for X, X - 1, ..., 0 are all active: X + 1 of them.
                    assert_eq!(tracker.depth(), X as usize + 1);
                    0
                }
                _ => x + sum(x - 1),
            })
        }

        assert_eq!(sum(X), EXPECTED);
    }
}