risingwave_error/
tonic.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod extra;
16
17use std::borrow::Cow;
18use std::error::Error;
19use std::sync::Arc;
20
21use serde::{Deserialize, Serialize};
22use thiserror_ext::AsReport;
23use tonic::metadata::{MetadataMap, MetadataValue};
24
25/// The key of the metadata field that contains the serialized error.
26const ERROR_KEY: &str = "risingwave-error-bin";
27
28/// The key of the metadata field that contains the call name.
29pub const CALL_KEY: &str = "risingwave-grpc-call";
30
31/// The service name that the error is from. Used to provide better error message.
32// TODO: also make it a field of `Extra`?
33type ServiceName = Cow<'static, str>;
34
35/// The error produced by the gRPC server and sent to the client on the wire.
36#[derive(Debug, Serialize, Deserialize)]
37struct ServerError {
38    error: serde_error::Error,
39    service_name: Option<ServiceName>,
40    extra: extra::Extra,
41}
42
43impl std::fmt::Display for ServerError {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        self.error.fmt(f)
46    }
47}
48
49impl std::error::Error for ServerError {
50    fn source(&self) -> Option<&(dyn Error + 'static)> {
51        self.error.source()
52    }
53
54    fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) {
55        // Provide self so that `ErrorIsFromTonicServerImpl` can work.
56        request.provide_ref(self);
57        // Provide extra fields.
58        self.extra.provide(request);
59    }
60}
61
62fn to_status<T>(error: &T, code: tonic::Code, service_name: Option<ServiceName>) -> tonic::Status
63where
64    T: ?Sized + std::error::Error,
65{
66    // Embed the whole error (`self`) and its source chain into the details field.
67    // At the same time, set the message field to the error message of `self` (without source chain).
68    // The redundancy of the current error's message is intentional in case the client ignores the `details` field.
69    let source = ServerError {
70        error: serde_error::Error::new(error),
71        service_name,
72        extra: extra::Extra::new(error),
73    };
74    let serialized = bincode::serialize(&source).unwrap();
75
76    let mut metadata = MetadataMap::new();
77    metadata.insert_bin(ERROR_KEY, MetadataValue::from_bytes(&serialized));
78
79    let mut status = tonic::Status::with_metadata(code, error.to_report_string(), metadata);
80    // Set the source of `tonic::Status`, though it's not likely to be used.
81    // This is only available before serializing to the wire. That's why we need to manually embed it
82    // into the `details` field.
83    status.set_source(Arc::new(source));
84    status
85}
86
87// TODO(error-handling): disallow constructing `tonic::Status` directly with `new` by clippy.
88#[easy_ext::ext(ToTonicStatus)]
89impl<T> T
90where
91    T: ?Sized + std::error::Error,
92{
93    /// Convert the error to [`tonic::Status`] with the given [`tonic::Code`] and service name.
94    ///
95    /// The source chain is preserved by pairing with [`TonicStatusWrapper`].
96    pub fn to_status(
97        &self,
98        code: tonic::Code,
99        service_name: impl Into<ServiceName>,
100    ) -> tonic::Status {
101        to_status(self, code, Some(service_name.into()))
102    }
103
104    /// Convert the error to [`tonic::Status`] with the given [`tonic::Code`] without specifying
105    /// the service name. Prefer [`to_status`] if possible.
106    ///
107    /// The source chain is preserved by pairing with [`TonicStatusWrapper`].
108    pub fn to_status_unnamed(&self, code: tonic::Code) -> tonic::Status {
109        to_status(self, code, None)
110    }
111}
112
113#[easy_ext::ext(ErrorIsFromTonicServerImpl)]
114impl<T> T
115where
116    T: ?Sized + std::error::Error,
117{
118    /// Returns whether the error is from the implementation of a tonic server, i.e., created
119    /// with [`ToTonicStatus::to_status`].
120    ///
121    /// This does not count errors initiated from the library, typically connection issues.
122    /// As a result, this function can be used to decide whether an error should be retried.
123    pub fn is_from_tonic_server_impl(&self) -> bool {
124        std::error::request_ref::<ServerError>(self).is_some()
125    }
126}
127
128/// A wrapper of [`tonic::Status`] that provides better error message and extracts
129/// the source chain from the `details` field.
130#[derive(Debug)]
131pub struct TonicStatusWrapper {
132    inner: tonic::Status,
133
134    /// The call name (path) of the gRPC request.
135    call: Option<String>,
136
137    /// Optional service name from the client side.
138    ///
139    /// # Explanation
140    ///
141    /// [`tonic::Status`] is used for both client and server side. When the error is created on
142    /// the server side, we encourage developers to provide the service name with
143    /// [`ToTonicStatus::to_status`], so that the info can be included in the HTTP response and
144    /// then extracted by the client side (in [`TonicStatusWrapper::new`]).
145    ///
146    /// However, if there's something wrong with the server side and the error is directly
147    /// created on the client side, the approach above is not applicable. In this case, the
148    /// caller should set a "client side" service name to provide better error message. This is
149    /// achieved by [`TonicStatusWrapperExt::with_client_side_service_name`].
150    client_side_service_name: Option<ServiceName>,
151}
152
153impl TonicStatusWrapper {
154    /// Create a new [`TonicStatusWrapper`] from the given [`tonic::Status`] and extract
155    /// the source chain from its `details` field.
156    pub fn new(mut status: tonic::Status) -> Self {
157        if status.source().is_none() {
158            if let Some(value) = status.metadata().get_bin(ERROR_KEY) {
159                if let Some(e) = value.to_bytes().ok().and_then(|serialized| {
160                    bincode::deserialize::<ServerError>(serialized.as_ref()).ok()
161                }) {
162                    status.set_source(Arc::new(e));
163                } else {
164                    tracing::warn!("failed to deserialize error from gRPC metadata");
165                }
166            }
167        }
168
169        let call = status
170            .metadata()
171            .get(CALL_KEY)
172            .and_then(|value| value.to_str().ok())
173            .map(str::to_owned);
174
175        Self {
176            inner: status,
177            call,
178            client_side_service_name: None,
179        }
180    }
181
182    /// Returns the reference to the inner [`tonic::Status`].
183    pub fn inner(&self) -> &tonic::Status {
184        &self.inner
185    }
186
187    /// Consumes `self` and returns the inner [`tonic::Status`].
188    pub fn into_inner(self) -> tonic::Status {
189        self.inner
190    }
191}
192
193impl From<tonic::Status> for TonicStatusWrapper {
194    fn from(status: tonic::Status) -> Self {
195        Self::new(status)
196    }
197}
198
199impl std::fmt::Display for TonicStatusWrapper {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        write!(f, "gRPC request")?;
202
203        if let Some(service_name) = self
204            .source()
205            .and_then(|s| s.downcast_ref::<ServerError>())
206            .and_then(|s| s.service_name.as_ref())
207            // if no service name from the server side, use the client side one
208            .or(self.client_side_service_name.as_ref())
209        {
210            write!(f, " to {} service", service_name)?;
211        }
212        if let Some(call) = &self.call {
213            write!(f, " (call `{}`)", call)?;
214        }
215        write!(f, " failed: {}: ", self.inner.code())?;
216
217        #[expect(rw::format_error)] // intentionally format the source itself
218        if let Some(source) = self.source() {
219            // Prefer the source chain from the `details` field.
220            write!(f, "{}", source)
221        } else {
222            write!(f, "{}", self.inner.message())
223        }
224    }
225}
226
227#[easy_ext::ext(TonicStatusWrapperExt)]
228impl<T> T
229where
230    T: Into<TonicStatusWrapper>,
231{
232    /// Set the client side service name to provide better error message.
233    ///
234    /// See the documentation on the field `client_side_service_name` for more details.
235    pub fn with_client_side_service_name(
236        self,
237        service_name: impl Into<ServiceName>,
238    ) -> TonicStatusWrapper {
239        let mut this = self.into();
240        this.client_side_service_name = Some(service_name.into());
241        this
242    }
243}
244
245impl std::error::Error for TonicStatusWrapper {
246    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
247        // Delegate to `self.inner` as if we're transparent.
248        self.inner.source()
249    }
250
251    fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) {
252        // The source error, typically a `ServerError`, may provide additional information through `extra`.
253        if let Some(source) = self.source() {
254            source.provide(request);
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// An error converted with `to_status` and sent through HTTP headers must come
    /// back with its message and source chain intact once wrapped on the client.
    #[test]
    fn test_source_chain_preserved() {
        #[derive(thiserror::Error, Debug)]
        #[error("{message}")]
        struct MyError {
            message: &'static str,
            source: Option<Box<MyError>>,
        }

        let inner = MyError {
            message: "inner",
            source: None,
        };
        let original = MyError {
            message: "outer",
            source: Some(Box::new(inner)),
        };

        // Server side: build the status and turn it into an HTTP response.
        let response = original
            .to_status(tonic::Code::Internal, "test")
            .into_http();

        // Client side: only the headers are available to rebuild the status.
        let client_status = tonic::Status::from_header_map(response.headers()).unwrap();
        let wrapper = TonicStatusWrapper::new(client_status);

        let rendered = wrapper.to_string();
        assert_eq!(
            rendered,
            "gRPC request to test service failed: Internal error: outer"
        );

        let outer = wrapper.source().unwrap();
        assert!(outer.is::<ServerError>());
        assert_eq!(outer.to_string(), "outer");

        let inner = outer.source().unwrap();
        assert_eq!(inner.to_string(), "inner");
    }
}