risingwave_frontend/webhook/
utils.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use anyhow::anyhow;
18use axum::Json;
19use axum::http::{HeaderMap, StatusCode};
20use axum::response::IntoResponse;
21use risingwave_common::row::OwnedRow;
22use risingwave_common::secret::LocalSecretManager;
23use risingwave_common::types::JsonbVal;
24use risingwave_pb::expr::ExprNode;
25use serde_json::json;
26use thiserror_ext::AsReport;
27
28use crate::expr::ExprImpl;
29
30#[derive(Debug)]
31pub struct WebhookError {
32    err: anyhow::Error,
33    code: StatusCode,
34}
35
36pub(crate) type Result<T> = std::result::Result<T, WebhookError>;
37
38pub(crate) fn err(err: impl Into<anyhow::Error>, code: StatusCode) -> WebhookError {
39    WebhookError {
40        err: err.into(),
41        code,
42    }
43}
44
45impl WebhookError {
46    #[cfg(test)]
47    pub(crate) fn code(&self) -> StatusCode {
48        self.code
49    }
50}
51
52impl From<anyhow::Error> for WebhookError {
53    fn from(value: anyhow::Error) -> Self {
54        WebhookError {
55            err: value,
56            code: StatusCode::INTERNAL_SERVER_ERROR,
57        }
58    }
59}
60
61impl std::fmt::Display for WebhookError {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        write!(f, "{}", self.err.as_report())
64    }
65}
66
67impl std::error::Error for WebhookError {}
68
69impl IntoResponse for WebhookError {
70    fn into_response(self) -> axum::response::Response {
71        let mut resp = Json(json!({
72            "error": format!("{}", self.err.as_report()),
73        }))
74        .into_response();
75        *resp.status_mut() = self.code;
76        resp
77    }
78}
79
80pub(crate) fn header_map_to_json(headers: &HeaderMap) -> JsonbVal {
81    let mut header_map = HashMap::new();
82
83    for (key, value) in headers {
84        let key = key.as_str().to_owned();
85        let value = value.to_str().unwrap_or("").to_owned();
86        header_map.insert(key, value);
87    }
88
89    let json_value = json!(header_map);
90    JsonbVal::from(json_value)
91}
92
93pub(crate) async fn authenticate_webhook_payload(
94    headers_jsonb: JsonbVal,
95    payload: &[u8],
96    webhook_source_info: &risingwave_pb::catalog::WebhookSourceInfo,
97) -> Result<()> {
98    let is_valid = if let Some(signature_expr) = webhook_source_info.signature_expr.clone() {
99        let secret = if let Some(secret_ref) = webhook_source_info.secret_ref {
100            LocalSecretManager::global()
101                .fill_secret(secret_ref)
102                .map_err(|e| err(e, StatusCode::NOT_FOUND))?
103        } else {
104            String::new()
105        };
106        verify_signature(headers_jsonb, secret.as_str(), payload, signature_expr).await?
107    } else {
108        true
109    };
110
111    if !is_valid {
112        return Err(err(
113            anyhow!("Signature verification failed"),
114            StatusCode::UNAUTHORIZED,
115        ));
116    }
117
118    Ok(())
119}
120
121pub(crate) async fn verify_signature(
122    headers_jsonb: JsonbVal,
123    secret: &str,
124    payload: &[u8],
125    signature_expr: ExprNode,
126) -> Result<bool> {
127    let row = OwnedRow::new(vec![
128        Some(headers_jsonb.into()),
129        Some(secret.into()),
130        Some(payload.into()),
131    ]);
132
133    let signature_expr_impl = ExprImpl::from_expr_proto(&signature_expr)
134        .map_err(|e| err(e, StatusCode::INTERNAL_SERVER_ERROR))?;
135
136    let result = signature_expr_impl
137        .eval_row(&row)
138        .await
139        .map_err(|e| {
140            tracing::error!(error = %e.as_report(), "Fail to validate for webhook events.");
141            err(e, StatusCode::INTERNAL_SERVER_ERROR)
142        })?
143        .ok_or_else(|| {
144            err(
145                anyhow!("`SECURE_COMPARE()` failed"),
146                StatusCode::BAD_REQUEST,
147            )
148        })?;
149    Ok(*result.as_bool())
150}
151
#[cfg(test)]
mod tests {
    use axum::http::header::{HeaderName, HeaderValue};

    use super::*;

    /// A single custom header must show up in the JSON under its exact name.
    #[test]
    fn test_header_map_to_json_preserves_header_names() {
        let headers: HeaderMap = [(
            HeaderName::from_static("x-custom-token"),
            HeaderValue::from_static("abc"),
        )]
        .into_iter()
        .collect();

        let rendered = header_map_to_json(&headers).to_string();
        let parsed: serde_json::Value = serde_json::from_str(&rendered).unwrap();

        assert_eq!(parsed["x-custom-token"], "abc");
    }
}