risingwave_common/util/panic.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! RisingWave aborts the execution in the panic hook by default to avoid unpredictability and
//! interference in concurrent programming as much as possible. Since the hook is called no matter
//! where the panic occurs, [`std::panic::catch_unwind`] is ineffective on its own: the process is
//! aborted before unwinding ever reaches it.
18//!
//! To work around this where catching a panic is actually desired, we provide a task-local flag
//! indicating whether we're in the context of catching unwind. The panic hook checks this flag to
//! decide whether to abort the execution (see the usage of [`is_catching_unwind`]).
22//!
//! This module provides several utility functions that wrap [`std::panic::catch_unwind`] and
//! related functions and set the flag properly. Code running under these contexts temporarily
//! disables the aborting behavior of the panic hook.
26
27use std::panic::UnwindSafe;
28
29use futures::Future;
30use tokio::task::futures::TaskLocalFuture;
31use tokio::task_local;
32
33task_local! {
34    /// A task-local flag indicating whether we're under the context of catching unwind.
35    static CATCH_UNWIND: ()
36}
37
38/// Invokes a closure, capturing the cause of an unwinding panic if one occurs.
39///
40/// See the module-level documentation for why this is needed.
41pub fn rw_catch_unwind<F: FnOnce() -> R + UnwindSafe, R>(f: F) -> std::thread::Result<R> {
42    CATCH_UNWIND.sync_scope((), || {
43        #[expect(clippy::disallowed_methods)]
44        std::panic::catch_unwind(f)
45    })
46}
47
48#[easy_ext::ext(FutureCatchUnwindExt)]
49pub impl<F: Future> F {
50    /// Catches unwinding panics while polling the future.
51    ///
52    /// See the module-level documentation for why this is needed.
53    fn rw_catch_unwind(self) -> TaskLocalFuture<(), futures::future::CatchUnwind<Self>>
54    where
55        Self: Sized + std::panic::UnwindSafe,
56    {
57        CATCH_UNWIND.scope(
58            (),
59            #[expect(clippy::disallowed_methods)]
60            futures::FutureExt::catch_unwind(self),
61        )
62    }
63}
64
65// TODO: extension for `Stream`.
66
67/// Returns whether the current scope is under the context of catching unwind (by calling
68/// `rw_catch_unwind`).
69pub fn is_catching_unwind() -> bool {
70    CATCH_UNWIND.try_with(|_| ()).is_ok()
71}
72
#[cfg(all(test, not(madsim)))]
#[expect(clippy::disallowed_methods)]
mod tests {
    use rusty_fork::rusty_fork_test;

    use super::*;

    /// Installs a panic hook mimicking `risingwave_rt::set_panic_hook`: it first delegates to
    /// whichever hook was installed before, then aborts the process unless the panic happened
    /// inside a catch-unwind context.
    fn set_panic_hook() {
        let previous_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(move |info| {
            previous_hook(info);
            // Keep the process alive only when one of the `rw_catch_unwind` helpers is
            // catching this panic.
            if is_catching_unwind() {
                return;
            }
            std::process::abort();
        }));
    }

    rusty_fork_test! {
        // A plain `std::panic::catch_unwind` does not set the flag, so the hook aborts the
        // forked child. `rusty_fork` asserts that the child succeeds, which makes this panic.
        #[test]
        #[should_panic]
        fn test_sync_not_work() {
            set_panic_hook();
            let _caught = std::panic::catch_unwind(|| panic!());
        }

        #[test]
        fn test_sync_rw() {
            set_panic_hook();
            let caught = rw_catch_unwind(|| panic!());
            assert!(caught.is_err());
        }

        #[test]
        fn test_async_rw() {
            set_panic_hook();
            let runtime = tokio::runtime::Runtime::new().unwrap();
            let caught = runtime.block_on(async { panic!() }.rw_catch_unwind());
            assert!(caught.is_err());
        }
    }
}