risingwave_common/util/panic.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! RisingWave aborts the execution in the panic hook by default to avoid unpredictability and
16//! interference in concurrent programming as much as possible. Since the hook is called no matter
17//! where the panic occurs, [`std::panic::catch_unwind`] will be a no-op.
18//!
//! To work around this under some circumstances, we provide a task-local flag in practice to
20//! indicate whether we're under the context of catching unwind. This is used in the panic hook to
21//! decide whether to abort the execution (see the usage of [`is_catching_unwind`]).
22//!
//! This module provides several utility functions wrapping [`std::panic::catch_unwind`] and other
24//! related functions to set the flag properly. Calling functions under these contexts will disable
25//! the aborting behavior in the panic hook temporarily.
26
27use std::panic::UnwindSafe;
28
29use futures::Future;
30use tokio::task::futures::TaskLocalFuture;
31use tokio::task_local;
32
task_local! {
    /// A task-local flag indicating whether we're under the context of catching unwind.
    ///
    /// The value is the unit type, so it carries no data of its own: the only
    /// observable state is whether the key is currently set. The `rw_catch_unwind`
    /// wrappers in this module set it for the duration of the wrapped call or
    /// future, and [`is_catching_unwind`] checks whether it is set.
    static CATCH_UNWIND: ()
}
37
38/// Invokes a closure, capturing the cause of an unwinding panic if one occurs.
39///
40/// See the module-level documentation for why this is needed.
41pub fn rw_catch_unwind<F: FnOnce() -> R + UnwindSafe, R>(f: F) -> std::thread::Result<R> {
42 CATCH_UNWIND.sync_scope((), || {
43 #[expect(clippy::disallowed_methods)]
44 std::panic::catch_unwind(f)
45 })
46}
47
48#[easy_ext::ext(FutureCatchUnwindExt)]
49pub impl<F: Future> F {
50 /// Catches unwinding panics while polling the future.
51 ///
52 /// See the module-level documentation for why this is needed.
53 fn rw_catch_unwind(self) -> TaskLocalFuture<(), futures::future::CatchUnwind<Self>>
54 where
55 Self: Sized + std::panic::UnwindSafe,
56 {
57 CATCH_UNWIND.scope(
58 (),
59 #[expect(clippy::disallowed_methods)]
60 futures::FutureExt::catch_unwind(self),
61 )
62 }
63}
64
65// TODO: extension for `Stream`.
66
67/// Returns whether the current scope is under the context of catching unwind (by calling
68/// `rw_catch_unwind`).
69pub fn is_catching_unwind() -> bool {
70 CATCH_UNWIND.try_with(|_| ()).is_ok()
71}
72
#[cfg(all(test, not(madsim)))]
#[expect(clippy::disallowed_methods)]
mod tests {

    use rusty_fork::rusty_fork_test;

    use super::*;

    /// Installs a panic hook mimicking `risingwave_rt::set_panic_hook`: it chains to
    /// the previously installed hook, then aborts the process unless the panic
    /// happened under a catch-unwind context.
    fn set_panic_hook() {
        let previous_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(move |panic_info| {
            // Let the previous hook report the panic first.
            previous_hook(panic_info);

            if is_catching_unwind() {
                return;
            }
            std::process::abort();
        }));
    }

    // Each test runs in a forked process because the hook may abort it.
    rusty_fork_test! {
        // The plain `std` API does not set the flag, so the hook aborts the forked
        // process. `rusty_fork` asserts that the forked process succeeds, hence
        // the test is expected to panic.
        #[test]
        #[should_panic]
        fn test_sync_not_work() {
            set_panic_hook();

            let _outcome = std::panic::catch_unwind(|| panic!());
        }

        // The synchronous wrapper sets the flag, so the panic is caught as `Err`.
        #[test]
        fn test_sync_rw() {
            set_panic_hook();

            let outcome = rw_catch_unwind(|| panic!());
            assert!(outcome.is_err());
        }

        // The future wrapper sets the flag while polling, so the panic is caught as `Err`.
        #[test]
        fn test_async_rw() {
            set_panic_hook();

            let guarded = async { panic!() }.rw_catch_unwind();

            let runtime = tokio::runtime::Runtime::new().unwrap();
            let outcome = runtime.block_on(guarded);
            assert!(outcome.is_err());
        }
    }
}
120}