risingwave_frontend/webhook/
mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16use std::sync::Arc;
17use std::sync::atomic::AtomicU32;
18
19use anyhow::{Context, anyhow};
20use axum::Router;
21use axum::body::Bytes;
22use axum::extract::{Extension, Path};
23use axum::http::{HeaderMap, Method, StatusCode};
24use axum::routing::post;
25use risingwave_common::array::{Array, ArrayBuilder, DataChunk};
26use risingwave_common::secret::LocalSecretManager;
27use risingwave_common::types::{DataType, JsonbVal, Scalar};
28use risingwave_pb::catalog::WebhookSourceInfo;
29use risingwave_pb::task_service::{FastInsertRequest, FastInsertResponse};
30use tokio::net::TcpListener;
31use tower::ServiceBuilder;
32use tower_http::add_extension::AddExtensionLayer;
33use tower_http::compression::CompressionLayer;
34use tower_http::cors::{self, CorsLayer};
35
36use crate::webhook::utils::{Result, err};
37mod utils;
38use risingwave_rpc_client::ComputeClient;
39
/// Shared, reference-counted handle to the [`WebhookService`].
///
/// This is the type the request handler extracts through `Extension<Service>`.
pub type Service = Arc<WebhookService>;

// We always use the `root` user to connect to the database to allow the webhook service to access
// all tables. The name is passed as the user when building the `SchemaPath` for table lookup.
const USER: &str = "root";
44
/// Everything resolved for one webhook request before its payload is inserted.
///
/// Produced by `handlers::acquire_table_info` and consumed by `handlers::handle_post_request`.
#[derive(Clone)]
pub struct FastInsertContext {
    /// Webhook settings cloned from the target table's catalog entry (signature expression,
    /// secret reference, batching flag, ...).
    pub webhook_source_info: WebhookSourceInfo,
    /// Insert request pre-filled with the table metadata; it is created with `data_chunk: None`
    /// and the handler sets the chunk right before sending.
    pub fast_insert_request: FastInsertRequest,
    /// Client returned by `choose_fast_insert_client`, used to run the insert.
    pub compute_client: ComputeClient,
}
51
/// State of the webhook HTTP server.
pub struct WebhookService {
    /// Socket address that `serve` binds its TCP listener to.
    webhook_addr: SocketAddr,
    /// Request counter, starting at zero; each POST takes the next value via `fetch_add` and uses
    /// it as its `request_id`.
    counter: AtomicU32,
}
56
57pub(super) mod handlers {
58    use jsonbb::Value;
59    use risingwave_common::array::JsonbArrayBuilder;
60    use risingwave_common::session_config::SearchPath;
61    use risingwave_pb::catalog::WebhookSourceInfo;
62    use risingwave_pb::task_service::fast_insert_response;
63    use utils::{header_map_to_json, verify_signature};
64
65    use super::*;
66    use crate::catalog::root_catalog::SchemaPath;
67    use crate::scheduler::choose_fast_insert_client;
68    use crate::session::SESSION_MANAGER;
69
70    pub async fn handle_post_request(
71        Extension(srv): Extension<Service>,
72        headers: HeaderMap,
73        Path((database, schema, table)): Path<(String, String, String)>,
74        body: Bytes,
75    ) -> Result<()> {
76        let request_id = srv
77            .counter
78            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
79        let FastInsertContext {
80            webhook_source_info,
81            mut fast_insert_request,
82            compute_client,
83        } = acquire_table_info(request_id, &database, &schema, &table).await?;
84
85        let WebhookSourceInfo {
86            signature_expr,
87            secret_ref,
88            wait_for_persistence: _,
89            is_batched,
90        } = webhook_source_info;
91
92        let is_valid = if let Some(signature_expr) = signature_expr {
93            let secret_string = if let Some(secret_ref) = secret_ref {
94                LocalSecretManager::global()
95                    .fill_secret(secret_ref)
96                    .map_err(|e| err(e, StatusCode::NOT_FOUND))?
97            } else {
98                String::new()
99            };
100
101            // Once limitation here is that the key is no longer case-insensitive, users must user the lowercase key when defining the webhook source table.
102            let headers_jsonb = header_map_to_json(&headers);
103
104            // verify the signature
105            verify_signature(
106                headers_jsonb,
107                secret_string.as_str(),
108                body.as_ref(),
109                signature_expr,
110            )
111            .await?
112        } else {
113            true
114        };
115
116        if !is_valid {
117            return Err(err(
118                anyhow!("Signature verification failed"),
119                StatusCode::UNAUTHORIZED,
120            ));
121        }
122
123        let data_chunk = generate_data_chunk(is_batched, &body)?;
124
125        // fill the data_chunk
126        fast_insert_request.data_chunk = Some(data_chunk.to_protobuf());
127        // execute on the compute node
128        let res = execute(fast_insert_request, compute_client).await?;
129
130        if res.status == fast_insert_response::Status::Succeeded as i32 {
131            Ok(())
132        } else {
133            Err(err(
134                anyhow!("Failed to fast insert: {}", res.error_message),
135                StatusCode::INTERNAL_SERVER_ERROR,
136            ))
137        }
138    }
139
140    fn generate_data_chunk(is_batched: bool, body: &Bytes) -> Result<DataChunk> {
141        let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);
142
143        if !is_batched {
144            // Use builder to obtain a single column & single row DataChunk
145            let json_value = Value::from_text(body).map_err(|e| {
146                err(
147                    anyhow!(e).context("Failed to parse body"),
148                    StatusCode::UNPROCESSABLE_ENTITY,
149                )
150            })?;
151
152            let jsonb_val = JsonbVal::from(json_value);
153            builder.append(Some(jsonb_val.as_scalar_ref()));
154
155            Ok(DataChunk::new(vec![builder.finish().into_ref()], 1))
156        } else {
157            let rows: Vec<_> = body.split(|&b| b == b'\n').collect();
158
159            for row in &rows {
160                let json_value = Value::from_text(row).map_err(|e| {
161                    err(
162                        anyhow!(e).context("Failed to parse body"),
163                        StatusCode::UNPROCESSABLE_ENTITY,
164                    )
165                })?;
166                let jsonb_val = JsonbVal::from(json_value);
167
168                builder.append(Some(jsonb_val.as_scalar_ref()));
169            }
170
171            Ok(DataChunk::new(
172                vec![builder.finish().into_ref()],
173                rows.len(),
174            ))
175        }
176    }
177
178    async fn acquire_table_info(
179        request_id: u32,
180        database: &String,
181        schema: &String,
182        table: &String,
183    ) -> Result<FastInsertContext> {
184        let session_mgr = SESSION_MANAGER
185            .get()
186            .expect("session manager has been initialized");
187
188        let frontend_env = session_mgr.env();
189
190        let search_path = SearchPath::default();
191        let schema_path = SchemaPath::new(Some(schema.as_str()), &search_path, USER);
192
193        let (webhook_source_info, table_id, version_id, row_id_index) = {
194            let reader = frontend_env.catalog_reader().read_guard();
195            let (table_catalog, _schema) = reader
196                .get_any_table_by_name(database.as_str(), schema_path, table)
197                .map_err(|e| err(e, StatusCode::NOT_FOUND))?;
198
199            let webhook_source_info = table_catalog
200                .webhook_info
201                .as_ref()
202                .ok_or_else(|| {
203                    err(
204                        anyhow!("Table `{}` is not with webhook source", table),
205                        StatusCode::FORBIDDEN,
206                    )
207                })?
208                .clone();
209            (
210                webhook_source_info,
211                table_catalog.id(),
212                table_catalog.version_id().expect("table must be versioned"),
213                table_catalog.row_id_index.map(|idx| idx as u32),
214            )
215        };
216
217        let fast_insert_request = FastInsertRequest {
218            table_id,
219            table_version_id: version_id,
220            column_indices: vec![0],
221            // leave the data_chunk empty for now
222            data_chunk: None,
223            row_id_index,
224            request_id,
225            wait_for_persistence: webhook_source_info.wait_for_persistence,
226        };
227
228        let compute_client = choose_fast_insert_client(table_id, frontend_env, request_id)
229            .await
230            .unwrap();
231
232        Ok(FastInsertContext {
233            webhook_source_info,
234            fast_insert_request,
235            compute_client,
236        })
237    }
238
239    async fn execute(
240        request: FastInsertRequest,
241        client: ComputeClient,
242    ) -> Result<FastInsertResponse> {
243        let response = client.fast_insert(request).await.map_err(|e| {
244            err(
245                anyhow!(e).context("Failed to execute on compute node"),
246                StatusCode::INTERNAL_SERVER_ERROR,
247            )
248        })?;
249        Ok(response)
250    }
251}
252
253impl WebhookService {
254    pub fn new(webhook_addr: SocketAddr) -> Self {
255        Self {
256            webhook_addr,
257            counter: AtomicU32::new(0),
258        }
259    }
260
261    pub async fn serve(self) -> anyhow::Result<()> {
262        use handlers::*;
263        let srv = Arc::new(self);
264
265        let cors_layer = CorsLayer::new()
266            .allow_origin(cors::Any)
267            .allow_methods(vec![Method::POST]);
268
269        let api_router: Router = Router::new()
270            .route("/:database/:schema/:table", post(handle_post_request))
271            .layer(
272                ServiceBuilder::new()
273                    .layer(AddExtensionLayer::new(srv.clone()))
274                    .into_inner(),
275            )
276            .layer(cors_layer);
277
278        let app: Router = Router::new()
279            .nest("/webhook", api_router)
280            .layer(CompressionLayer::new());
281
282        let listener = TcpListener::bind(&srv.webhook_addr)
283            .await
284            .context("Failed to bind dashboard address")?;
285
286        #[cfg(not(madsim))]
287        axum::serve(listener, app)
288            .await
289            .context("Failed to serve dashboard service")?;
290
291        Ok(())
292    }
293}
294
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use crate::webhook::WebhookService;

    /// Manual smoke test: builds a service for `127.0.0.1:4560` and awaits `serve`.
    ///
    /// Marked `#[ignore]`, so it only runs when requested explicitly.
    #[tokio::test]
    #[ignore]
    async fn test_webhook_server() -> anyhow::Result<()> {
        let addr: SocketAddr = ([127, 0, 0, 1], 4560).into();
        WebhookService::new(addr).serve().await
    }
}