risingwave_common/util/panic.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! RisingWave aborts the execution in the panic hook by default to avoid unpredictability and
//! interference in concurrent programming as much as possible. Since the hook is called no matter
//! where the panic occurs, [`std::panic::catch_unwind`] is effectively a no-op: the process is
//! aborted before the unwinding panic can be caught.
//!
//! To work around this in certain circumstances, we provide a task-local flag indicating whether
//! we're in the context of catching unwind. The panic hook checks this flag to decide whether to
//! abort the execution (see the usage of [`is_catching_unwind`]).
//!
//! This module provides several utility functions wrapping [`std::panic::catch_unwind`] and other
//! related functions to set the flag properly. Calling functions in these contexts will disable
//! the aborting behavior in the panic hook temporarily.
use std::panic::UnwindSafe;
use futures::Future;
use tokio::task::futures::TaskLocalFuture;
use tokio::task_local;
task_local! {
/// A task-local flag indicating whether we're under the context of catching unwind.
///
/// Only the presence of the flag matters, so its value type is `()`. It is installed by the
/// `rw_catch_unwind` helpers (via `sync_scope` for closures and `scope` for futures) and probed
/// by `is_catching_unwind` (via `try_with`).
static CATCH_UNWIND: ()
}
/// Runs `f`, catching an unwinding panic instead of letting the panic hook abort the process.
///
/// `f` is executed inside a synchronous scope of the task-local catch-unwind flag, so
/// `is_catching_unwind` reports `true` while it runs. Returns `Ok` with the closure's result, or
/// `Err` with the panic payload if `f` panicked.
///
/// See the module-level documentation for why this is needed.
pub fn rw_catch_unwind<F: FnOnce() -> R + UnwindSafe, R>(f: F) -> std::thread::Result<R> {
    // `std::panic::catch_unwind` is disallowed crate-wide; this is the sanctioned wrapper.
    #[expect(clippy::disallowed_methods)]
    let guarded = move || std::panic::catch_unwind(f);
    CATCH_UNWIND.sync_scope((), guarded)
}
#[easy_ext::ext(FutureCatchUnwindExt)]
pub impl<F: Future> F {
/// Catches unwinding panics while polling the future.
///
/// See the module-level documentation for why this is needed.
fn rw_catch_unwind(self) -> TaskLocalFuture<(), futures::future::CatchUnwind<Self>>
where
Self: Sized + std::panic::UnwindSafe,
{
CATCH_UNWIND.scope(
(),
#[expect(clippy::disallowed_methods)]
futures::FutureExt::catch_unwind(self),
)
}
}
// TODO: extension for `Stream`.
/// Reports whether the caller is currently running inside a `rw_catch_unwind` scope (either the
/// closure or the future variant).
///
/// The panic hook uses this to decide whether to let the panic unwind instead of aborting.
pub fn is_catching_unwind() -> bool {
    // `try_with` fails when no enclosing scope has installed the flag.
    CATCH_UNWIND.try_with(|_| true).unwrap_or(false)
}
#[cfg(all(test, not(madsim)))]
// The tests deliberately call `std::panic::catch_unwind` directly (see `test_sync_not_work`),
// which triggers clippy's `disallowed_methods` lint, as the `expect` in `rw_catch_unwind` shows.
#[expect(clippy::disallowed_methods)]
mod tests {
use rusty_fork::rusty_fork_test;
use super::*;
/// Simulates the behavior of `risingwave_rt::set_panic_hook`.
///
/// Chains to the previously installed hook, then aborts the process unless the panic happened
/// inside a catch-unwind scope (as reported by `is_catching_unwind`).
fn set_panic_hook() {
let old = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
old(info);
if !is_catching_unwind() {
std::process::abort();
}
}))
}
// Each test runs in a forked process (via `rusty_fork`), so installing an aborting panic hook
// does not affect the other tests.
rusty_fork_test! {
// A plain `std::panic::catch_unwind` does not set the flag, so the hook aborts the child
// process.
#[test]
#[should_panic] // `rusty_fork` asserts that the forked process succeeds, so this should panic.
fn test_sync_not_work() {
set_panic_hook();
let _result = std::panic::catch_unwind(|| panic!());
}
// `rw_catch_unwind` sets the flag, so the hook returns and the panic surfaces as `Err`.
#[test]
fn test_sync_rw() {
set_panic_hook();
let result = rw_catch_unwind(|| panic!());
assert!(result.is_err());
}
// Same as above for futures: the panic happens while the runtime polls the wrapped future.
#[test]
fn test_async_rw() {
set_panic_hook();
let fut = async { panic!() }.rw_catch_unwind();
let result = tokio::runtime::Runtime::new().unwrap().block_on(fut);
assert!(result.is_err());
}
}
}