risingwave_frontend/webhook/
mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16use std::sync::Arc;
17use std::sync::atomic::AtomicU32;
18
19use anyhow::{Context, anyhow};
20use axum::Router;
21use axum::body::Bytes;
22use axum::extract::{Extension, Path};
23use axum::http::{HeaderMap, Method, StatusCode};
24use axum::routing::post;
25#[cfg(not(madsim))]
26use axum_server::tls_openssl::OpenSSLConfig;
27use itertools::Itertools;
28use pgwire::pg_protocol::TlsConfig;
29use risingwave_common::array::{Array, ArrayBuilder, DataChunk};
30use risingwave_common::catalog::TableId;
31use risingwave_common::session_config::SearchPath;
32use risingwave_common::types::{DataType, JsonbVal, Scalar};
33use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
34use risingwave_pb::catalog::WebhookSourceInfo;
35use risingwave_pb::task_service::{FastInsertRequest, FastInsertResponse};
36use tokio::net::TcpListener;
37use tower::ServiceBuilder;
38use tower_http::add_extension::AddExtensionLayer;
39use tower_http::compression::CompressionLayer;
40use tower_http::cors::{self, CorsLayer};
41
42use crate::catalog::root_catalog::SchemaPath;
43use crate::scheduler::choose_fast_insert_client;
44use crate::session::SESSION_MANAGER;
45use crate::webhook::payload::{build_json_access_builder, owned_row_from_payload_row};
46use crate::webhook::utils::{Result, authenticate_webhook_payload, err, header_map_to_json};
47pub(crate) mod payload;
48pub(crate) mod utils;
49pub(crate) mod websocket;
50use risingwave_rpc_client::ComputeClient;
51
/// Shared, reference-counted handle to the [`WebhookService`]; this is the type
/// the request handlers receive through the axum `Extension` extractor.
pub type Service = Arc<WebhookService>;

// We always use the `root` user to connect to the database to allow the webhook service to access all tables.
const USER: &str = "root";
56
/// Description of one insertable column of a webhook table, as captured from the
/// table catalog when the insert context is built.
#[derive(Clone, Debug)]
pub(crate) struct WebhookTableColumnDesc {
    /// Column name, copied from the catalog's column descriptor.
    pub(crate) name: String,
    /// Column data type, copied from the catalog's column descriptor.
    pub(crate) data_type: DataType,
    /// Primary-key flag as reported by the catalog's `columns_to_insert()`.
    pub(crate) is_pk: bool,
}
63
/// How an incoming webhook body is mapped onto the columns of the target table.
#[derive(Clone, Debug)]
pub(crate) enum PayloadSchema {
    /// The table has exactly one insertable column and it is of type `Jsonb`;
    /// each payload row is parsed as JSON and stored in that column.
    SingleJsonb,
    /// Any other column layout; each payload row is decoded into one value per
    /// listed column.
    FullSchema {
        /// The insertable columns, in the order they were collected from the catalog.
        columns: Vec<WebhookTableColumnDesc>,
    },
}
71
72impl PayloadSchema {
73    fn new(columns: Vec<WebhookTableColumnDesc>) -> Self {
74        if columns.len() == 1 && columns[0].data_type == DataType::Jsonb {
75            Self::SingleJsonb
76        } else {
77            Self::FullSchema { columns }
78        }
79    }
80}
81
/// Everything a webhook request needs in order to fast-insert into a table.
///
/// Produced by `acquire_table_info` from the frontend catalog plus a compute
/// client chosen for the request.
#[derive(Clone)]
pub(crate) struct WebhookTableInsertContext {
    /// Webhook configuration cloned from the table catalog's `webhook_info`.
    pub(crate) webhook_source_info: WebhookSourceInfo,
    /// Catalog id of the target table.
    pub(crate) table_id: TableId,
    /// Version id of the target table at lookup time.
    pub(crate) table_version_id: u64,
    /// Row-id column index reported by `columns_to_insert()`, if any.
    pub(crate) row_id_index: Option<u32>,
    /// Client for the compute node that will execute the fast insert.
    pub(crate) compute_client: ComputeClient,
    /// How the request body is mapped onto the table's columns.
    pub(crate) payload_schema: PayloadSchema,
}
91
/// HTTP service that accepts webhook `POST` requests and also hosts the ingest
/// WebSocket router on the same listener.
pub struct WebhookService {
    /// Socket address the server binds to.
    webhook_addr: SocketAddr,
    /// When `Some`, the server is served over TLS using the configured cert/key files.
    tls_config: Option<TlsConfig>,
    /// Incremented once per handled `POST` request; the previous value becomes the request id.
    counter: AtomicU32,
}
97
98pub(super) mod handlers {
99    use jsonbb::Value;
100    use risingwave_common::array::JsonbArrayBuilder;
101    use risingwave_pb::task_service::fast_insert_response;
102
103    use super::*;
104
105    pub async fn handle_post_request(
106        Extension(srv): Extension<Service>,
107        headers: HeaderMap,
108        Path((database, schema, table)): Path<(String, String, String)>,
109        body: Bytes,
110    ) -> Result<()> {
111        let request_id = srv
112            .counter
113            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
114        let WebhookTableInsertContext {
115            webhook_source_info,
116            table_id,
117            table_version_id,
118            row_id_index,
119            compute_client,
120            payload_schema,
121        } = acquire_table_info(request_id, &database, &schema, &table).await?;
122        authenticate_webhook_payload(
123            header_map_to_json(&headers),
124            body.as_ref(),
125            &webhook_source_info,
126        )
127        .await?;
128
129        let data_chunk = match &payload_schema {
130            PayloadSchema::SingleJsonb => {
131                generate_data_chunk(webhook_source_info.is_batched, &body)?
132            }
133            PayloadSchema::FullSchema { columns } => {
134                let rows: Vec<_> = if webhook_source_info.is_batched {
135                    body.split(|&byte| byte == b'\n')
136                        .filter(|b| !b.is_empty())
137                        .collect()
138                } else {
139                    vec![body.as_ref()]
140                };
141                let mut access_builder = build_json_access_builder(&headers)?;
142                let mut chunk_builder = DataChunkBuilder::new(
143                    columns
144                        .iter()
145                        .map(|column| column.data_type.clone())
146                        .collect_vec(),
147                    rows.len().saturating_add(1).max(1),
148                );
149
150                for row in rows {
151                    let owned_row = owned_row_from_payload_row(&mut access_builder, columns, row)?;
152                    assert!(chunk_builder.append_one_row(owned_row).is_none());
153                }
154
155                let Some(chunk) = chunk_builder.consume_all() else {
156                    return Ok(());
157                };
158
159                chunk
160            }
161        };
162
163        let fast_insert_request = FastInsertRequest {
164            table_id,
165            table_version_id,
166            data_chunk: Some(data_chunk.to_protobuf()),
167            row_id_index,
168            request_id,
169            wait_for_persistence: webhook_source_info.wait_for_persistence,
170        };
171        // execute on the compute node
172        let res = execute(fast_insert_request, compute_client).await?;
173
174        if res.status == fast_insert_response::Status::Succeeded as i32 {
175            Ok(())
176        } else {
177            Err(err(
178                anyhow!("Failed to fast insert: {}", res.error_message),
179                StatusCode::INTERNAL_SERVER_ERROR,
180            ))
181        }
182    }
183
184    fn generate_data_chunk(is_batched: bool, body: &Bytes) -> Result<DataChunk> {
185        let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb);
186
187        if !is_batched {
188            // Use builder to obtain a single column & single row DataChunk
189            let json_value = Value::from_text(body).map_err(|e| {
190                err(
191                    anyhow!(e).context("Failed to parse body"),
192                    StatusCode::UNPROCESSABLE_ENTITY,
193                )
194            })?;
195
196            let jsonb_val = JsonbVal::from(json_value);
197            builder.append(Some(jsonb_val.as_scalar_ref()));
198
199            Ok(DataChunk::new(vec![builder.finish().into_ref()], 1))
200        } else {
201            let rows: Vec<_> = body
202                .split(|&b| b == b'\n')
203                .filter(|b| !b.is_empty())
204                .collect();
205
206            for row in &rows {
207                let json_value = Value::from_text(row).map_err(|e| {
208                    err(
209                        anyhow!(e).context("Failed to parse body"),
210                        StatusCode::UNPROCESSABLE_ENTITY,
211                    )
212                })?;
213                let jsonb_val = JsonbVal::from(json_value);
214
215                builder.append(Some(jsonb_val.as_scalar_ref()));
216            }
217
218            Ok(DataChunk::new(
219                vec![builder.finish().into_ref()],
220                rows.len(),
221            ))
222        }
223    }
224
225    pub(crate) async fn acquire_table_info(
226        request_id: u32,
227        database: &str,
228        schema: &str,
229        table: &str,
230    ) -> Result<WebhookTableInsertContext> {
231        let session_mgr = SESSION_MANAGER
232            .get()
233            .expect("session manager has been initialized");
234
235        let frontend_env = session_mgr.env();
236
237        let search_path = SearchPath::default();
238        let schema_path = SchemaPath::new(Some(schema), &search_path, USER);
239
240        let (webhook_source_info, table_id, table_version_id, row_id_index, payload_schema) = {
241            let reader = frontend_env.catalog_reader().read_guard();
242            let (table_catalog, _schema) = reader
243                .get_any_table_by_name(database, schema_path, table)
244                .map_err(|e| err(e, StatusCode::NOT_FOUND))?;
245
246            let (columns_to_insert, row_id_index) = table_catalog.columns_to_insert();
247            let payload_schema = PayloadSchema::new(
248                columns_to_insert
249                    .map(|(column, is_pk)| WebhookTableColumnDesc {
250                        is_pk,
251                        name: column.column_desc.name.clone(),
252                        data_type: column.column_desc.data_type.clone(),
253                    })
254                    .collect(),
255            );
256            let row_id_index = row_id_index.map(|row_id_index| row_id_index as u32);
257
258            let webhook_source_info = table_catalog
259                .webhook_info
260                .as_ref()
261                .ok_or_else(|| {
262                    err(
263                        anyhow!("Table `{}` is not with webhook source", table),
264                        StatusCode::FORBIDDEN,
265                    )
266                })?
267                .clone();
268            (
269                webhook_source_info,
270                table_catalog.id(),
271                table_catalog.version_id().expect("table must be versioned"),
272                row_id_index,
273                payload_schema,
274            )
275        };
276
277        let compute_client = choose_fast_insert_client(table_id, frontend_env, request_id)
278            .await
279            .unwrap();
280
281        Ok(WebhookTableInsertContext {
282            webhook_source_info,
283            table_id,
284            table_version_id,
285            row_id_index,
286            compute_client,
287            payload_schema,
288        })
289    }
290
291    async fn execute(
292        request: FastInsertRequest,
293        client: ComputeClient,
294    ) -> Result<FastInsertResponse> {
295        let response = client.fast_insert(request).await.map_err(|e| {
296            err(
297                anyhow!(e).context("Failed to execute on compute node"),
298                StatusCode::INTERNAL_SERVER_ERROR,
299            )
300        })?;
301        Ok(response)
302    }
303}
304
305pub(crate) use handlers::acquire_table_info;
306
307impl WebhookService {
308    pub fn new(webhook_addr: SocketAddr, tls_config: Option<TlsConfig>) -> Self {
309        Self {
310            webhook_addr,
311            tls_config,
312            counter: AtomicU32::new(0),
313        }
314    }
315
316    pub async fn serve(self) -> anyhow::Result<()> {
317        use handlers::*;
318        let srv = Arc::new(self);
319
320        let cors_layer = CorsLayer::new()
321            .allow_origin(cors::Any)
322            .allow_methods(vec![Method::POST]);
323
324        let webhook_router: Router = Router::new()
325            .route("/{database}/{schema}/{table}", post(handle_post_request))
326            .layer(
327                ServiceBuilder::new()
328                    .layer(AddExtensionLayer::new(srv.clone()))
329                    .into_inner(),
330            )
331            .layer(cors_layer);
332
333        // The ingest WebSocket endpoint shares the same listener as the webhook service.
334        let ingest_svc = Arc::new(websocket::IngestService::new());
335        let ingest_router = websocket::build_router(ingest_svc);
336
337        let app: Router = Router::new()
338            .nest("/webhook", webhook_router)
339            .nest("/ingest", ingest_router)
340            .layer(CompressionLayer::new());
341
342        #[cfg(not(madsim))]
343        {
344            if let Some(tls_config) = &srv.tls_config {
345                let config = OpenSSLConfig::from_pem_file(&tls_config.cert, &tls_config.key)
346                    .context("Failed to load TLS config for webhook service")?;
347                axum_server::bind_openssl(srv.webhook_addr, config)
348                    .serve(app.into_make_service())
349                    .await
350                    .context("Failed to serve webhook service over TLS")?;
351            } else {
352                let listener = TcpListener::bind(&srv.webhook_addr)
353                    .await
354                    .context("Failed to bind dashboard address")?;
355                axum::serve(listener, app)
356                    .await
357                    .context("Failed to serve dashboard service")?;
358            }
359        }
360
361        Ok(())
362    }
363}
364
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr};

    use crate::webhook::WebhookService;

    /// Starts the webhook service on `127.0.0.1:4560` without TLS.
    ///
    /// Marked `#[ignore]`, so it only runs on explicit request — presumably
    /// because `serve` keeps running until the server stops.
    #[tokio::test]
    #[ignore]
    async fn test_webhook_server() -> anyhow::Result<()> {
        let listen_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 4560));
        WebhookService::new(listen_addr, None).serve().await
    }
}
377}