risingwave_frontend/webhook/
websocket.rs

1// Copyright 2026 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! WebSocket-based streaming DML ingest endpoint.
16//!
17//! Clients open a `WebSocket` connection per table. The first text frame authenticates the session
18//! and later frames carry unsigned batches of upsert / delete DML messages. Each websocket
19//! payload carries a monotonically increasing `dml_batch_id`.
20//!
21//! Wire format (JSON over text `WebSocket` frames):
22//!
23//! **Client → Server** (authenticated init):
24//! ```json
25//! {"type": "init", "timestamp": 1760000000000}
26//! ```
27//!
28//! **Client → Server** (DML batch):
29//! ```json
30//! {
31//!   "dml_batch_id": 1,
32//!   "items": [
33//!     {"op": "upsert", "data": {"id": 1, "name": "foo"}},
34//!     {"op": "delete", "data": {"id": 1, "name": "foo"}}
35//!   ]
36//! }
37//! ```
38//!
39//! **Server → Client** (ack):
40//! ```json
41//! {"ack": 1}
42//! ```
43//!
44//! **Server → Client** (fatal — the connection will close):
45//! ```json
46//! {"fatal": "DML channel closed, please reconnect"}
47//! ```
48use std::sync::Arc;
49use std::sync::atomic::{AtomicU32, Ordering};
50use std::time::{Duration, SystemTime, UNIX_EPOCH};
51
52use axum::Router;
53use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
54use axum::extract::{Extension, Path};
55use axum::http::HeaderMap;
56use axum::response::IntoResponse;
57use axum::routing::get;
58use futures::SinkExt;
59use futures::stream::StreamExt;
60use jsonbb::Value;
61use risingwave_common::array::{Op, StreamChunk};
62use risingwave_common::row::OwnedRow;
63use risingwave_common::types::{DataType, JsonbVal};
64use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
65use risingwave_pb::task_service::{
66    IngestDmlInitRequest, IngestDmlPayloadRequest, IngestDmlRequest, ingest_dml_request,
67    ingest_dml_response,
68};
69use serde::{Deserialize, Serialize};
70use serde_json::value::RawValue;
71use thiserror_ext::AsReport;
72use tokio::time::timeout;
73use tokio_stream::wrappers::ReceiverStream;
74use tower::ServiceBuilder;
75use tower_http::add_extension::AddExtensionLayer;
76
77use crate::session::SESSION_MANAGER;
78use crate::webhook::payload::{build_json_access_builder, owned_row_from_payload_row};
79use crate::webhook::utils::{authenticate_webhook_payload, header_map_to_json};
80use crate::webhook::{PayloadSchema, acquire_table_info};
81
/// Required value of the `"type"` field in the first (authentication) frame.
const INIT_MESSAGE_TYPE: &str = "init";
/// Capacity of the bounded channel feeding the compute-node ingest stream;
/// the sender awaits when full, back-pressuring a client that outpaces ingest.
const INGEST_DML_REQUEST_BUFFER_SIZE: usize = 64;
/// How long we wait for the client's init frame before failing the connection.
const WEBSOCKET_INIT_TIMEOUT: Duration = Duration::from_secs(10);
85
/// Shared state for the ingest service.
pub struct IngestService {
    // Monotonic counter used to hand out a distinct `request_id` per connection.
    counter: AtomicU32,
}

impl IngestService {
    /// Creates a service whose request-id counter starts at zero.
    pub fn new() -> Self {
        Self {
            counter: AtomicU32::new(0),
        }
    }
}

// `Default` mirrors `new()` so the type satisfies clippy's `new_without_default`
// and composes with `Default`-based construction.
impl Default for IngestService {
    fn default() -> Self {
        Self::new()
    }
}
98
/// Shared handle to [`IngestService`], injected into handlers via an axum extension.
pub type ServiceRef = Arc<IngestService>;
100
/// First client frame: `{"type": "init", "timestamp": <epoch ms>}`.
#[derive(Debug, Deserialize)]
struct InitRequest {
    /// The wire field is named `type`, a Rust keyword, hence the rename.
    #[serde(rename = "type")]
    msg_type: String,
    /// Client-reported epoch milliseconds; checked against the server clock skew window.
    timestamp: i64,
}

/// One client DML frame: a batch of items sharing a `dml_batch_id`.
#[derive(Debug, Deserialize)]
struct RawDmlBatchRequest {
    /// Client-assigned id; must increase strictly monotonically per connection.
    pub dml_batch_id: u64,
    /// The DML operations in this batch; may be empty.
    pub items: Vec<RawDmlRequest>,
}

/// A single not-yet-validated DML item as received on the wire.
#[derive(Debug, Deserialize)]
pub struct RawDmlRequest {
    /// Operation name; validated into [`DmlOp`] by the `TryFrom` impl below.
    pub op: Option<String>,
    // `RawValue` is unsized; `Box<RawValue>` is serde's owned form that preserves the original JSON text.
    pub data: Box<RawValue>,
}

/// Validated DML operation kind; `upsert`/`insert`/`update` all map to `Upsert`.
#[derive(Debug, Clone, Copy)]
enum DmlOp {
    Upsert,
    Delete,
}

/// A validated DML item: the operation plus the untouched JSON payload text.
#[derive(Debug)]
pub struct DmlRequest {
    op: DmlOp,
    data: Box<RawValue>,
}
132
133impl TryFrom<RawDmlRequest> for DmlRequest {
134    type Error = String;
135
136    fn try_from(raw: RawDmlRequest) -> Result<Self, Self::Error> {
137        let op = match raw.op.as_deref() {
138            None => {
139                return Err("missing op, expected upsert/delete".to_owned());
140            }
141            Some("upsert" | "insert" | "update") => DmlOp::Upsert,
142            Some("delete") => DmlOp::Delete,
143            Some(other) => {
144                return Err(format!("unknown op '{other}', expected upsert/delete"));
145            }
146        };
147
148        Ok(Self { op, data: raw.data })
149    }
150}
151
/// Server-to-client messages; `untagged` keeps the wire shape as plain
/// `{"ack": n}` / `{"fatal": "..."}` objects.
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ServerMessage {
    /// The compute node acknowledged the batch with this `dml_batch_id`.
    Ack { ack: u64 },
    /// Unrecoverable error; the server drops the connection after sending this.
    Fatal { fatal: String },
}
158
/// A client DML batch converted into the protobuf payload for the ingest stream.
#[derive(Debug)]
struct PreparedDmlBatch {
    // Echoed back to the client in the eventual ack.
    dml_batch_id: u64,
    payload: IngestDmlPayloadRequest,
}

/// Write half of the split `WebSocket`.
type WsTx = futures::stream::SplitSink<WebSocket, Message>;
/// Read half of the split `WebSocket`.
type WsRx = futures::stream::SplitStream<WebSocket>;
167
168pub async fn ws_handler(
169    ws: WebSocketUpgrade,
170    Extension(svc): Extension<ServiceRef>,
171    headers: HeaderMap,
172    Path((database, schema, table)): Path<(String, String, String)>,
173) -> impl IntoResponse {
174    let request_id = svc.counter.fetch_add(1, Ordering::Relaxed);
175    let headers_jsonb = header_map_to_json(&headers);
176    ws.on_upgrade(move |socket| {
177        handle_connection(
178            socket,
179            database,
180            schema,
181            table,
182            request_id,
183            headers,
184            headers_jsonb,
185        )
186    })
187}
188
189async fn handle_connection(
190    socket: WebSocket,
191    database: String,
192    schema: String,
193    table: String,
194    request_id: u32,
195    headers: HeaderMap,
196    headers_jsonb: JsonbVal,
197) {
198    let (mut ws_tx, ws_rx) = socket.split();
199
200    if let Err(e) = try_handle_connection(
201        &mut ws_tx,
202        ws_rx,
203        database,
204        schema,
205        table,
206        request_id,
207        headers,
208        headers_jsonb,
209    )
210    .await
211    {
212        let _ = send_fatal(&mut ws_tx, e).await;
213    }
214}
215
/// Drives one authenticated WebSocket session end to end.
///
/// Phases: resolve the target table → read + authenticate the first (init)
/// frame → open the gRPC ingest stream to the compute node and complete its
/// init handshake → loop, forwarding client DML batches into the stream and
/// relaying compute-node acks back to the client.
///
/// Returns `Err(reason)` with a client-facing message; the caller forwards it
/// as a `fatal` frame and the connection is dropped.
async fn try_handle_connection(
    ws_tx: &mut WsTx,
    mut ws_rx: WsRx,
    database: String,
    schema: String,
    table: String,
    request_id: u32,
    headers: HeaderMap,
    headers_jsonb: JsonbVal,
) -> Result<(), String> {
    let table_info = acquire_table_info(request_id, &database, &schema, &table)
        .await
        .map_err(|e| format!("table lookup failed: {}", e.as_report()))?;

    let session_mgr = SESSION_MANAGER.get().expect("session manager initialized");
    let max_clock_skew_ms = session_mgr
        .env()
        .frontend_config()
        .webhook_auth_max_clock_skew_ms;
    // Destructure the table info into locals used by the phases below.
    let webhook_source_info = table_info.webhook_source_info;
    let table_id = table_info.table_id;
    let table_version_id = table_info.table_version_id;
    let row_id_index = table_info.row_id_index;
    let compute_client = table_info.compute_client;
    let payload_schema = table_info.payload_schema;

    // Phase 1: the first frame must be a text init message, within the timeout.
    // A clean close (or EOF) before init is not an error — just end the session.
    let init_text = match timeout(WEBSOCKET_INIT_TIMEOUT, ws_rx.next()).await {
        Ok(Some(Ok(Message::Text(text)))) => text,
        Ok(Some(Ok(Message::Close(_)))) | Ok(None) => return Ok(()),
        Ok(Some(Ok(_))) => {
            return Err("the first WebSocket frame must be a text init message".to_owned());
        }
        Ok(Some(Err(e))) => {
            return Err(format!(
                "failed to read WebSocket init message: {}",
                e.as_report()
            ));
        }
        Err(_) => {
            return Err(format!(
                "timed out waiting for WebSocket init message after {}s",
                WEBSOCKET_INIT_TIMEOUT.as_secs()
            ));
        }
    };

    // Phase 2: verify the init frame's signature against the table's webhook
    // source config before trusting anything inside it.
    if let Err(e) =
        authenticate_webhook_payload(headers_jsonb, init_text.as_bytes(), &webhook_source_info)
            .await
    {
        return Err(format!("{}", e.as_report()));
    }

    // Only after authentication do we parse the frame and check clock skew.
    parse_and_validate_init_request(&init_text, max_clock_skew_ms)?;

    // Phase 3: open the bidirectional ingest stream. The init request is
    // enqueued before the stream is created so it is the first thing sent.
    let (ingest_req_tx, ingest_req_rx) = tokio::sync::mpsc::channel(INGEST_DML_REQUEST_BUFFER_SIZE);
    if ingest_req_tx
        .send(IngestDmlRequest {
            request: Some(ingest_dml_request::Request::Init(IngestDmlInitRequest {
                table_id,
                table_version_id,
                request_id,
                row_id_index,
            })),
        })
        .await
        .is_err()
    {
        return Err("failed to enqueue init request for ingest stream".to_owned());
    }

    let mut ingest_resp_stream = compute_client
        .ingest_dml(ReceiverStream::new(ingest_req_rx))
        .await
        .map_err(|e| format!("failed to open ingest stream: {}", e.as_report()))?;

    // The first response must be the init ack; anything else is a protocol error.
    match ingest_resp_stream.message().await {
        Ok(Some(resp)) => match resp.response {
            Some(ingest_dml_response::Response::Init(_)) => {}
            _ => {
                return Err("unexpected init response from ingest stream".to_owned());
            }
        },
        Ok(None) => return Err("ingest stream closed during init".to_owned()),
        Err(e) => {
            return Err(format!("ingest stream init error: {}", e.as_report()));
        }
    }

    // Three watermarks over the client's strictly increasing dml_batch_id:
    // - seen:      highest id received from the client (monotonicity check);
    // - forwarded: highest id actually sent to the compute node (empty batches
    //              are acked locally and never forwarded, so seen can be ahead);
    // - acked:     highest id the compute node confirmed back to the client.
    let mut last_seen_dml_batch_id = 0_u64;
    let mut last_forwarded_dml_batch_id = 0_u64;
    let mut last_acked_dml_batch_id = 0_u64;

    // Phase 4: pump loop. `biased` polls the ingest-response arm first so
    // pending acks are drained before more client work is accepted.
    loop {
        tokio::select! {
            biased;
            ingest_resp = ingest_resp_stream.message() => {
                match ingest_resp {
                    Ok(Some(resp)) => match resp.response {
                        Some(ingest_dml_response::Response::Ack(ack)) => {
                            // A valid ack must fall in (last_acked, last_forwarded].
                            if ack.dml_batch_id <= last_acked_dml_batch_id
                                || ack.dml_batch_id > last_forwarded_dml_batch_id
                            {
                                return Err(format!(
                                    "unexpected ack for dml_batch_id {}",
                                    ack.dml_batch_id
                                ));
                            }
                            last_acked_dml_batch_id = ack.dml_batch_id;
                            send_server_message(
                                ws_tx,
                                ServerMessage::Ack {
                                    ack: ack.dml_batch_id,
                                },
                            )
                            .await
                            .map_err(|_| "websocket connection closed while sending ack".to_owned())?;
                        }
                        Some(ingest_dml_response::Response::Init(_)) => {
                            return Err("unexpected extra init response from ingest stream".to_owned());
                        }
                        None => return Err("empty response from ingest stream".to_owned()),
                    },
                    Ok(None) => return Err("ingest stream closed".to_owned()),
                    Err(e) => return Err(format!("ingest stream error: {}", e.as_report())),
                }
            }

            ws_msg = ws_rx.next() => {
                // Only text frames carry DML; other frame types (ping/pong/binary)
                // are ignored, and a close/EOF ends the session cleanly.
                let text = match ws_msg {
                    Some(Ok(Message::Text(text))) => text,
                    Some(Ok(Message::Close(_))) | None => break,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => {
                        return Err(format!(
                            "failed to read WebSocket message: {}",
                            e.as_report()
                        ));
                    }
                };

                let raw_dml_batch = match serde_json::from_str::<RawDmlBatchRequest>(&text) {
                    Ok(batch) => batch,
                    Err(e) => return Err(format!("malformed payload: {}", e.as_report())),
                };

                last_seen_dml_batch_id = validate_monotonic_dml_batch_id(
                    raw_dml_batch.dml_batch_id,
                    last_seen_dml_batch_id,
                )?;

                // Empty batches are acked immediately without touching the stream.
                if raw_dml_batch.items.is_empty() {
                    send_server_message(
                        ws_tx,
                        ServerMessage::Ack {
                            ack: raw_dml_batch.dml_batch_id,
                        },
                    )
                    .await
                    .map_err(|_| "websocket connection closed while sending ack".to_owned())?;
                    continue;
                }

                let prepared_batch = prepare_dml_batch_payload(
                    &headers,
                    raw_dml_batch,
                    &payload_schema,
                )
                .map_err(|e| format!("failed to prepare DML batch: {e}"))?;

                let PreparedDmlBatch {
                    dml_batch_id,
                    payload,
                } = prepared_batch;

                // May await on the bounded channel, back-pressuring the client.
                ingest_req_tx
                    .send(IngestDmlRequest {
                        request: Some(ingest_dml_request::Request::Payload(payload)),
                    })
                    .await
                    .map_err(|_| "ingest stream request channel closed".to_owned())?;
                last_forwarded_dml_batch_id = dml_batch_id;
            }
        }
    }

    Ok(())
}
404
405fn parse_and_validate_init_request(
406    text: &str,
407    max_clock_skew_ms: u64,
408) -> Result<InitRequest, String> {
409    let init_req: InitRequest = serde_json::from_str(text)
410        .map_err(|e| format!("malformed init message: {}", e.as_report()))?;
411
412    if init_req.msg_type != INIT_MESSAGE_TYPE {
413        return Err(format!(
414            "invalid init message type '{}', expected '{}'",
415            init_req.msg_type, INIT_MESSAGE_TYPE
416        ));
417    }
418
419    validate_timestamp_skew(init_req.timestamp, max_clock_skew_ms)?;
420    Ok(init_req)
421}
422
/// Rejects a client timestamp whose absolute distance from the server clock
/// exceeds `max_clock_skew_ms`. Negative timestamps are rejected outright.
fn validate_timestamp_skew(timestamp_ms: i64, max_clock_skew_ms: u64) -> Result<(), String> {
    if timestamp_ms < 0 {
        return Err("timestamp must be a non-negative epoch millisecond".to_owned());
    }

    // A pre-epoch system clock degrades to "now = 0" rather than erroring.
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_millis()) as i128;
    // i128 arithmetic: both operands fit comfortably, so the subtraction cannot overflow.
    let skew_ms = (now_ms - i128::from(timestamp_ms)).abs();

    if skew_ms <= i128::from(max_clock_skew_ms) {
        Ok(())
    } else {
        Err(format!(
            "timestamp skew {}ms exceeds the allowed {}ms window",
            skew_ms, max_clock_skew_ms
        ))
    }
}
443
/// Enforces strictly increasing `dml_batch_id`s, returning the new high-water
/// mark on success so the caller can assign it directly.
fn validate_monotonic_dml_batch_id(
    dml_batch_id: u64,
    last_seen_dml_batch_id: u64,
) -> Result<u64, String> {
    if dml_batch_id > last_seen_dml_batch_id {
        Ok(dml_batch_id)
    } else {
        Err(format!(
            "dml_batch_id must increase monotonically: received {} after {}",
            dml_batch_id, last_seen_dml_batch_id
        ))
    }
}
456
457fn prepare_dml_batch_payload(
458    headers: &HeaderMap,
459    raw_dml_batch: RawDmlBatchRequest,
460    payload_schema: &PayloadSchema,
461) -> Result<PreparedDmlBatch, String> {
462    let dml_batch_id = raw_dml_batch.dml_batch_id;
463    let raw_dml_reqs = raw_dml_batch.items;
464
465    match payload_schema {
466        PayloadSchema::SingleJsonb => {
467            let mut chunk_builder = DataChunkBuilder::new(
468                vec![DataType::Jsonb],
469                raw_dml_reqs.len().saturating_add(1).max(1),
470            );
471            let mut ops = Vec::with_capacity(raw_dml_reqs.len());
472
473            for (index, raw_dml_req) in raw_dml_reqs.into_iter().enumerate() {
474                let item_index = index + 1;
475                let dml_req = DmlRequest::try_from(raw_dml_req)
476                    .map_err(|e| format!("dml_batch_id {dml_batch_id} item {item_index}: {e}"))?;
477
478                let row = Value::from_text(dml_req.data.get().as_bytes())
479                    .map(|json_value| OwnedRow::new(vec![Some(JsonbVal::from(json_value).into())]))
480                    .map_err(|e| {
481                        format!(
482                            "dml_batch_id {dml_batch_id} item {item_index}: Failed to parse body: {}",
483                            e.as_report()
484                        )
485                    })?;
486
487                let output = chunk_builder.append_one_row(row);
488                debug_assert!(output.is_none());
489                ops.push(match dml_req.op {
490                    DmlOp::Upsert => Op::Insert,
491                    DmlOp::Delete => Op::Delete,
492                });
493            }
494
495            let data_chunk = chunk_builder
496                .consume_all()
497                .expect("buffered rows should produce a chunk");
498            let chunk = StreamChunk::from_parts(ops, data_chunk);
499            let payload = IngestDmlPayloadRequest {
500                dml_batch_id,
501                chunk: Some(chunk.to_protobuf()),
502            };
503
504            Ok(PreparedDmlBatch {
505                dml_batch_id,
506                payload,
507            })
508        }
509        PayloadSchema::FullSchema { columns } => {
510            let mut access_builder =
511                build_json_access_builder(headers).map_err(|e| format!("{}", e.as_report()))?;
512            let mut chunk_builder = DataChunkBuilder::new(
513                columns
514                    .iter()
515                    .map(|column| column.data_type.clone())
516                    .collect(),
517                raw_dml_reqs.len().saturating_add(1).max(1),
518            );
519            let mut ops = Vec::with_capacity(raw_dml_reqs.len());
520
521            for (index, raw_dml_req) in raw_dml_reqs.into_iter().enumerate() {
522                let item_index = index + 1;
523                let dml_req = DmlRequest::try_from(raw_dml_req)
524                    .map_err(|e| format!("dml_batch_id {dml_batch_id} item {item_index}: {e}"))?;
525
526                let row = owned_row_from_payload_row(
527                    &mut access_builder,
528                    columns,
529                    dml_req.data.get().as_bytes(),
530                )
531                .map_err(|e| {
532                    format!(
533                        "dml_batch_id {dml_batch_id} item {item_index}: {}",
534                        e.as_report()
535                    )
536                })?;
537
538                let output = chunk_builder.append_one_row(row);
539                debug_assert!(output.is_none());
540                ops.push(match dml_req.op {
541                    DmlOp::Upsert => Op::Insert,
542                    DmlOp::Delete => Op::Delete,
543                });
544            }
545
546            let data_chunk = chunk_builder
547                .consume_all()
548                .expect("buffered rows should produce a chunk");
549            let chunk = StreamChunk::from_parts(ops, data_chunk);
550            let payload = IngestDmlPayloadRequest {
551                dml_batch_id,
552                chunk: Some(chunk.to_protobuf()),
553            };
554
555            Ok(PreparedDmlBatch {
556                dml_batch_id,
557                payload,
558            })
559        }
560    }
561}
562
563async fn send_server_message(ws_tx: &mut WsTx, msg: ServerMessage) -> Result<(), String> {
564    let text = serde_json::to_string(&msg).map_err(|e| format!("{}", e.as_report()))?;
565    ws_tx
566        .send(Message::Text(text.into()))
567        .await
568        .map_err(|e| format!("{}", e.as_report()))
569}
570
571async fn send_fatal(ws_tx: &mut WsTx, fatal: String) -> Result<(), String> {
572    send_server_message(ws_tx, ServerMessage::Fatal { fatal }).await
573}
574
575pub fn build_router(svc: ServiceRef) -> Router {
576    Router::new()
577        .route("/{database}/{schema}/{table}", get(ws_handler))
578        .layer(
579            ServiceBuilder::new()
580                .layer(AddExtensionLayer::new(svc))
581                .into_inner(),
582        )
583}
584
585#[cfg(test)]
586mod tests {
587    use axum::http::{HeaderMap, HeaderValue};
588    use risingwave_common::row::Row;
589    use risingwave_common::types::{DataType, ScalarImpl, ToOwnedDatum};
590
591    use super::*;
592    use crate::webhook::WebhookTableColumnDesc;
593
594    fn raw_json(text: &str) -> Box<RawValue> {
595        serde_json::from_str(text).unwrap()
596    }
597
598    fn raw_req(op: Option<&str>, data: &str) -> RawDmlRequest {
599        RawDmlRequest {
600            op: op.map(str::to_owned),
601            data: raw_json(data),
602        }
603    }
604
605    fn raw_batch(dml_batch_id: u64, items: Vec<RawDmlRequest>) -> RawDmlBatchRequest {
606        RawDmlBatchRequest {
607            dml_batch_id,
608            items,
609        }
610    }
611
612    fn test_columns(columns: &[(&str, DataType, bool)]) -> Vec<WebhookTableColumnDesc> {
613        columns
614            .iter()
615            .map(|(name, data_type, is_pk)| WebhookTableColumnDesc {
616                name: (*name).to_owned(),
617                data_type: data_type.clone(),
618                is_pk: *is_pk,
619            })
620            .collect()
621    }
622
623    #[test]
624    fn test_validate_monotonic_dml_batch_id() {
625        assert_eq!(validate_monotonic_dml_batch_id(5, 1).unwrap(), 5);
626    }
627
628    #[test]
629    fn test_validate_monotonic_dml_batch_id_rejects_non_monotonic_sequence() {
630        let err = validate_monotonic_dml_batch_id(3, 3).unwrap_err();
631        assert!(err.contains("dml_batch_id must increase monotonically"));
632    }
633
634    #[test]
635    fn test_parse_and_validate_init_request() {
636        let now_ms = SystemTime::now()
637            .duration_since(UNIX_EPOCH)
638            .unwrap_or(Duration::ZERO)
639            .as_millis() as i64;
640        let init = format!(r#"{{"type":"init","timestamp":{now_ms}}}"#);
641        assert_eq!(
642            parse_and_validate_init_request(&init, 300_000)
643                .unwrap()
644                .timestamp,
645            now_ms
646        );
647    }
648
649    #[test]
650    fn test_parse_and_validate_init_request_rejects_stale_timestamp() {
651        let init = r#"{"type":"init","timestamp":0}"#;
652        let err = parse_and_validate_init_request(init, 1).unwrap_err();
653        assert!(err.contains("timestamp skew"));
654    }
655
656    #[test]
657    fn test_dml_request_requires_op() {
658        let err = DmlRequest::try_from(raw_req(None, r#"{"id":1}"#)).unwrap_err();
659        assert_eq!(err, "missing op, expected upsert/delete");
660    }
661
662    #[test]
663    fn test_empty_batch_is_valid_json_protocol_input() {
664        let raw_dml_batch = raw_batch(42, vec![]);
665        assert_eq!(
666            validate_monotonic_dml_batch_id(raw_dml_batch.dml_batch_id, 41).unwrap(),
667            42
668        );
669        assert!(raw_dml_batch.items.is_empty());
670    }
671
672    #[test]
673    fn test_prepare_dml_batch_payload_builds_single_chunk_for_batch() {
674        let raw_dml_batch = raw_batch(
675            10,
676            vec![
677                raw_req(
678                    Some("upsert"),
679                    r#"{"id":1,"price":"19.99","created_at":"2026-04-15 10:00:00","name":"alice"}"#,
680                ),
681                raw_req(
682                    Some("delete"),
683                    r#"{"id":2,"price":"29.99","created_at":"2026-04-16 10:00:00","name":"bob"}"#,
684                ),
685            ],
686        );
687        let payload_schema = PayloadSchema::FullSchema {
688            columns: test_columns(&[
689                ("id", DataType::Int32, true),
690                ("price", DataType::Decimal, false),
691                ("created_at", DataType::Timestamp, false),
692                ("name", DataType::Varchar, false),
693            ]),
694        };
695
696        let prepared_batch =
697            prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema).unwrap();
698
699        assert_eq!(prepared_batch.dml_batch_id, 10);
700
701        let payload = prepared_batch.payload;
702        assert_eq!(payload.dml_batch_id, 10);
703        let chunk = StreamChunk::from_protobuf(payload.chunk.as_ref().unwrap()).unwrap();
704
705        assert_eq!(chunk.ops(), &[Op::Insert, Op::Delete]);
706        let mut rows = chunk.rows();
707        let row = rows.next().unwrap().1;
708        assert_eq!(row.datum_at(0).to_owned_datum(), Some(ScalarImpl::Int32(1)));
709        assert!(matches!(
710            row.datum_at(1).to_owned_datum(),
711            Some(ScalarImpl::Decimal(_))
712        ));
713        assert!(matches!(
714            row.datum_at(2).to_owned_datum(),
715            Some(ScalarImpl::Timestamp(_))
716        ));
717        assert_eq!(
718            row.datum_at(3).to_owned_datum(),
719            Some(ScalarImpl::Utf8("alice".into()))
720        );
721
722        let row = rows.next().unwrap().1;
723        assert_eq!(row.datum_at(0).to_owned_datum(), Some(ScalarImpl::Int32(2)));
724    }
725
726    #[test]
727    fn test_prepare_dml_batch_payload_returns_first_item_error() {
728        let raw_dml_batch = raw_batch(
729            11,
730            vec![
731                raw_req(Some("delete"), r#"{"id":1,"name":"alice"}"#),
732                raw_req(Some("upsert"), r#"{"name":"bob"}"#),
733            ],
734        );
735        let payload_schema = PayloadSchema::FullSchema {
736            columns: test_columns(&[
737                ("id", DataType::Int32, true),
738                ("name", DataType::Varchar, false),
739            ]),
740        };
741
742        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
743            .unwrap_err();
744
745        assert!(err.contains("dml_batch_id 11 item 2"));
746        assert!(err.contains("failed to decode webhook JSON payload"));
747    }
748
749    #[test]
750    fn test_prepare_dml_batch_payload_allows_delete_with_only_pk() {
751        let raw_dml_batch = raw_batch(
752            12,
753            vec![
754                raw_req(Some("upsert"), r#"{"id":7,"name":"alice"}"#),
755                raw_req(Some("delete"), r#"{"id":7}"#),
756            ],
757        );
758        let payload_schema = PayloadSchema::FullSchema {
759            columns: test_columns(&[
760                ("id", DataType::Int32, true),
761                ("name", DataType::Varchar, false),
762            ]),
763        };
764
765        let prepared_batch =
766            prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema).unwrap();
767        let chunk =
768            StreamChunk::from_protobuf(prepared_batch.payload.chunk.as_ref().unwrap()).unwrap();
769
770        assert_eq!(prepared_batch.dml_batch_id, 12);
771        assert_eq!(chunk.ops(), &[Op::Insert, Op::Delete]);
772        let mut rows = chunk.rows();
773        assert_eq!(
774            rows.next().unwrap().1.datum_at(1).to_owned_datum(),
775            Some(ScalarImpl::Utf8("alice".into()))
776        );
777        assert_eq!(rows.next().unwrap().1.datum_at(1).to_owned_datum(), None);
778    }
779
780    #[test]
781    fn test_prepare_dml_batch_payload_rejects_incomplete_composite_pk_insert() {
782        let raw_dml_batch = raw_batch(
783            51,
784            vec![raw_req(Some("upsert"), r#"{"id":1,"name":"alice"}"#)],
785        );
786        let payload_schema = PayloadSchema::FullSchema {
787            columns: test_columns(&[
788                ("tenant_id", DataType::Int32, true),
789                ("id", DataType::Int32, true),
790                ("name", DataType::Varchar, false),
791            ]),
792        };
793
794        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
795            .unwrap_err();
796
797        assert!(err.contains("dml_batch_id 51 item 1"));
798        assert!(err.contains("failed to decode webhook JSON payload"));
799    }
800
801    #[test]
802    fn test_prepare_dml_batch_payload_rejects_incomplete_composite_pk_delete() {
803        let raw_dml_batch = raw_batch(52, vec![raw_req(Some("delete"), r#"{"tenant_id":1}"#)]);
804        let payload_schema = PayloadSchema::FullSchema {
805            columns: test_columns(&[
806                ("tenant_id", DataType::Int32, true),
807                ("id", DataType::Int32, true),
808                ("name", DataType::Varchar, false),
809            ]),
810        };
811
812        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
813            .unwrap_err();
814
815        assert!(err.contains("dml_batch_id 52 item 1"));
816        assert!(err.contains("failed to decode webhook JSON payload"));
817    }
818
819    #[test]
820    fn test_prepare_dml_batch_payload_applies_supported_decoder_headers() {
821        let mut headers = HeaderMap::new();
822        headers.insert(
823            "x-rw-webhook-json-timestamp-handling-mode",
824            HeaderValue::from_static("milli"),
825        );
826        headers.insert(
827            "x-rw-webhook-json-time-handling-mode",
828            HeaderValue::from_static("milli"),
829        );
830        headers.insert(
831            "x-rw-webhook-json-bigint-unsigned-handling-mode",
832            HeaderValue::from_static("precise"),
833        );
834
835        let payload_schema = PayloadSchema::FullSchema {
836            columns: test_columns(&[
837                ("id", DataType::Int32, true),
838                ("event_time", DataType::Timestamp, false),
839            ]),
840        };
841        let prepared_timestamp_batch = prepare_dml_batch_payload(
842            &headers,
843            raw_batch(
844                13,
845                vec![raw_req(
846                    Some("upsert"),
847                    r#"{"id":1,"event_time":1712800800123}"#,
848                )],
849            ),
850            &payload_schema,
851        )
852        .unwrap();
853        let timestamp_chunk =
854            StreamChunk::from_protobuf(prepared_timestamp_batch.payload.chunk.as_ref().unwrap())
855                .unwrap();
856        assert!(matches!(
857            timestamp_chunk
858                .rows()
859                .next()
860                .unwrap()
861                .1
862                .datum_at(1)
863                .to_owned_datum(),
864            Some(ScalarImpl::Timestamp(_))
865        ));
866
867        let payload_schema = PayloadSchema::FullSchema {
868            columns: test_columns(&[
869                ("id", DataType::Int32, true),
870                ("event_time", DataType::Time, false),
871            ]),
872        };
873        let prepared_time_batch = prepare_dml_batch_payload(
874            &headers,
875            raw_batch(
876                14,
877                vec![raw_req(Some("upsert"), r#"{"id":2,"event_time":3723123}"#)],
878            ),
879            &payload_schema,
880        )
881        .unwrap();
882        let time_chunk =
883            StreamChunk::from_protobuf(prepared_time_batch.payload.chunk.as_ref().unwrap())
884                .unwrap();
885        assert!(matches!(
886            time_chunk
887                .rows()
888                .next()
889                .unwrap()
890                .1
891                .datum_at(1)
892                .to_owned_datum(),
893            Some(ScalarImpl::Time(_))
894        ));
895
896        let payload_schema = PayloadSchema::FullSchema {
897            columns: test_columns(&[
898                ("id", DataType::Int32, true),
899                ("amount", DataType::Decimal, false),
900            ]),
901        };
902        let prepared_decimal_batch = prepare_dml_batch_payload(
903            &headers,
904            raw_batch(
905                15,
906                vec![raw_req(Some("upsert"), r#"{"id":3,"amount":"AeJA"}"#)],
907            ),
908            &payload_schema,
909        )
910        .unwrap();
911        let decimal_chunk =
912            StreamChunk::from_protobuf(prepared_decimal_batch.payload.chunk.as_ref().unwrap())
913                .unwrap();
914        assert!(matches!(
915            decimal_chunk
916                .rows()
917                .next()
918                .unwrap()
919                .1
920                .datum_at(1)
921                .to_owned_datum(),
922            Some(ScalarImpl::Decimal(_))
923        ));
924    }
925
926    #[test]
927    fn test_prepare_dml_batch_payload_rejects_invalid_decoder_headers() {
928        for (header, value, expected) in [
929            (
930                "x-rw-webhook-json-timestamp-handling-mode",
931                "invalid",
932                "unrecognized `x-rw-webhook-json-timestamp-handling-mode` value",
933            ),
934            (
935                "x-rw-webhook-json-timestamptz-handling-mode",
936                "invalid",
937                "invalid webhook JSON decoder option",
938            ),
939            (
940                "x-rw-webhook-json-time-handling-mode",
941                "invalid",
942                "unrecognized `x-rw-webhook-json-time-handling-mode` value",
943            ),
944            (
945                "x-rw-webhook-json-bigint-unsigned-handling-mode",
946                "invalid",
947                "unrecognized `x-rw-webhook-json-bigint-unsigned-handling-mode` value",
948            ),
949            (
950                "x-rw-webhook-json-handle-toast-columns",
951                "invalid",
952                "unrecognized `x-rw-webhook-json-handle-toast-columns` value",
953            ),
954        ] {
955            let mut headers = HeaderMap::new();
956            headers.insert(header, HeaderValue::from_static(value));
957            let raw_dml_batch = raw_batch(
958                21,
959                vec![raw_req(Some("upsert"), r#"{"id":1,"name":"alice"}"#)],
960            );
961            let payload_schema = PayloadSchema::FullSchema {
962                columns: test_columns(&[
963                    ("id", DataType::Int32, true),
964                    ("name", DataType::Varchar, false),
965                ]),
966            };
967
968            let err =
969                prepare_dml_batch_payload(&headers, raw_dml_batch, &payload_schema).unwrap_err();
970            assert!(err.contains(expected), "{header}");
971        }
972    }
973
974    #[test]
975    fn test_prepare_dml_batch_payload_returns_first_type_error_with_item_index() {
976        let raw_dml_batch = raw_batch(
977            31,
978            vec![
979                raw_req(Some("upsert"), r#"{"id":1,"name":"alice"}"#),
980                raw_req(Some("upsert"), r#"{"id":"not-an-int","name":"bob"}"#),
981            ],
982        );
983        let payload_schema = PayloadSchema::FullSchema {
984            columns: test_columns(&[
985                ("id", DataType::Int32, true),
986                ("name", DataType::Varchar, false),
987            ]),
988        };
989
990        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
991            .unwrap_err();
992
993        assert!(err.contains("dml_batch_id 31 item 2"));
994        assert!(err.contains("failed to decode webhook JSON payload"));
995    }
996
997    #[test]
998    fn test_prepare_dml_batch_payload_single_jsonb_accepts_scalar_json() {
999        let raw_dml_batch = raw_batch(41, vec![raw_req(Some("upsert"), "123")]);
1000
1001        let prepared_batch = prepare_dml_batch_payload(
1002            &HeaderMap::new(),
1003            raw_dml_batch,
1004            &PayloadSchema::SingleJsonb,
1005        )
1006        .unwrap();
1007        let payload = prepared_batch.payload;
1008        let chunk = StreamChunk::from_protobuf(payload.chunk.as_ref().unwrap()).unwrap();
1009
1010        assert_eq!(prepared_batch.dml_batch_id, 41);
1011        assert_eq!(payload.dml_batch_id, 41);
1012        assert_eq!(chunk.ops(), &[Op::Insert]);
1013        assert!(matches!(
1014            chunk.rows().next().unwrap().1.datum_at(0).to_owned_datum(),
1015            Some(ScalarImpl::Jsonb(_))
1016        ));
1017    }
1018}