use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use axum::Router;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::{Extension, Path};
use axum::http::HeaderMap;
use axum::response::IntoResponse;
use axum::routing::get;
use futures::SinkExt;
use futures::stream::StreamExt;
use jsonbb::Value;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, JsonbVal};
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_pb::task_service::{
    IngestDmlInitRequest, IngestDmlPayloadRequest, IngestDmlRequest, ingest_dml_request,
    ingest_dml_response,
};
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use thiserror_ext::AsReport;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use tower::ServiceBuilder;
use tower_http::add_extension::AddExtensionLayer;

use crate::session::SESSION_MANAGER;
use crate::webhook::payload::{build_json_access_builder, owned_row_from_payload_row};
use crate::webhook::utils::{authenticate_webhook_payload, header_map_to_json};
use crate::webhook::{PayloadSchema, acquire_table_info};

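/// Expected value of the `type` field in the client's first (init) message.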
const INIT_MESSAGE_TYPE: &str = "init";
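/// Capacity of the bounded channel buffering requests to the compute-node ingest stream.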
const INGEST_DML_REQUEST_BUFFER_SIZE: usize = 64;
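/// How long to wait for the init message before rejecting a newly upgraded connection.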
const WEBSOCKET_INIT_TIMEOUT: Duration = Duration::from_secs(10);

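/// Shared state for the WebSocket ingest service; the counter hands out a unique
/// request id to each incoming connection.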
pub struct IngestService {
    counter: AtomicU32,
}

impl IngestService {
    pub fn new() -> Self {
        Self {
            counter: AtomicU32::new(0),
        }
    }
}

pub type ServiceRef = Arc<IngestService>;

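/// First text frame a client must send after the upgrade, e.g.
/// `{"type":"init","timestamp":1712800800123}` with the timestamp in epoch milliseconds.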
#[derive(Debug, Deserialize)]
struct InitRequest {
    #[serde(rename = "type")]
    msg_type: String,
    timestamp: i64,
}

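/// One DML batch frame as sent by the client, e.g.
/// `{"dml_batch_id":1,"items":[{"op":"upsert","data":{...}},{"op":"delete","data":{...}}]}`
/// (the `data` objects are illustrative). `dml_batch_id` must strictly increase
/// within a connection.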
#[derive(Debug, Deserialize)]
struct RawDmlBatchRequest {
    pub dml_batch_id: u64,
    pub items: Vec<RawDmlRequest>,
}

#[derive(Debug, Deserialize)]
pub struct RawDmlRequest {
    pub op: Option<String>,
    pub data: Box<RawValue>,
}

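/// Normalized DML operation: `upsert`, `insert`, and `update` all map to `Upsert`.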
#[derive(Debug, Clone, Copy)]
enum DmlOp {
    Upsert,
    Delete,
}

#[derive(Debug)]
pub struct DmlRequest {
    op: DmlOp,
    data: Box<RawValue>,
}

impl TryFrom<RawDmlRequest> for DmlRequest {
    type Error = String;

    fn try_from(raw: RawDmlRequest) -> Result<Self, Self::Error> {
        let op = match raw.op.as_deref() {
            None => {
                return Err("missing op, expected upsert/delete".to_owned());
            }
            Some("upsert" | "insert" | "update") => DmlOp::Upsert,
            Some("delete") => DmlOp::Delete,
            Some(other) => {
                return Err(format!("unknown op '{other}', expected upsert/delete"));
            }
        };

        Ok(Self { op, data: raw.data })
    }
}

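/// Server-to-client frames, serialized without a tag: acks as `{"ack":<dml_batch_id>}`
/// and unrecoverable errors as `{"fatal":"<reason>"}`.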
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ServerMessage {
    Ack { ack: u64 },
    Fatal { fatal: String },
}

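/// A client batch already converted into the protobuf payload for the ingest stream.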
#[derive(Debug)]
struct PreparedDmlBatch {
    dml_batch_id: u64,
    payload: IngestDmlPayloadRequest,
}

type WsTx = futures::stream::SplitSink<WebSocket, Message>;
type WsRx = futures::stream::SplitStream<WebSocket>;

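/// Axum handler for the WebSocket ingest endpoint: assigns a per-connection request id,
/// snapshots the request headers as JSONB for payload authentication, and upgrades the
/// connection into an ingest session.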
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    Extension(svc): Extension<ServiceRef>,
    headers: HeaderMap,
    Path((database, schema, table)): Path<(String, String, String)>,
) -> impl IntoResponse {
    let request_id = svc.counter.fetch_add(1, Ordering::Relaxed);
    let headers_jsonb = header_map_to_json(&headers);
    ws.on_upgrade(move |socket| {
        handle_connection(
            socket,
            database,
            schema,
            table,
            request_id,
            headers,
            headers_jsonb,
        )
    })
}

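/// Splits the socket and runs the session; any error from the session is reported back
/// to the client as a single fatal message on a best-effort basis.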
async fn handle_connection(
    socket: WebSocket,
    database: String,
    schema: String,
    table: String,
    request_id: u32,
    headers: HeaderMap,
    headers_jsonb: JsonbVal,
) {
    let (mut ws_tx, ws_rx) = socket.split();

    if let Err(e) = try_handle_connection(
        &mut ws_tx,
        ws_rx,
        database,
        schema,
        table,
        request_id,
        headers,
        headers_jsonb,
    )
    .await
    {
        let _ = send_fatal(&mut ws_tx, e).await;
    }
}

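/// Drives a single WebSocket ingest session:
///
/// 1. Looks up the target table and its webhook source info.
/// 2. Waits up to `WEBSOCKET_INIT_TIMEOUT` for a text init frame, authenticates it
///    against the webhook source info, and validates its type and timestamp skew.
/// 3. Opens the `ingest_dml` stream to the compute node and performs the init
///    request/response exchange.
/// 4. Loops: forwards each text frame as a DML batch on the ingest stream and relays
///    compute-side acks back to the client until either side closes or errors out.
///
/// A sketch of the expected exchange (field values illustrative):
///
/// ```text
/// client -> {"type":"init","timestamp":1712800800123}
/// client -> {"dml_batch_id":1,"items":[{"op":"upsert","data":{"id":1}}]}
/// server -> {"ack":1}
/// ```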
async fn try_handle_connection(
    ws_tx: &mut WsTx,
    mut ws_rx: WsRx,
    database: String,
    schema: String,
    table: String,
    request_id: u32,
    headers: HeaderMap,
    headers_jsonb: JsonbVal,
) -> Result<(), String> {
    let table_info = acquire_table_info(request_id, &database, &schema, &table)
        .await
        .map_err(|e| format!("table lookup failed: {}", e.as_report()))?;

    let session_mgr = SESSION_MANAGER.get().expect("session manager initialized");
    let max_clock_skew_ms = session_mgr
        .env()
        .frontend_config()
        .webhook_auth_max_clock_skew_ms;
    let webhook_source_info = table_info.webhook_source_info;
    let table_id = table_info.table_id;
    let table_version_id = table_info.table_version_id;
    let row_id_index = table_info.row_id_index;
    let compute_client = table_info.compute_client;
    let payload_schema = table_info.payload_schema;

    let init_text = match timeout(WEBSOCKET_INIT_TIMEOUT, ws_rx.next()).await {
        Ok(Some(Ok(Message::Text(text)))) => text,
        Ok(Some(Ok(Message::Close(_)))) | Ok(None) => return Ok(()),
        Ok(Some(Ok(_))) => {
            return Err("the first WebSocket frame must be a text init message".to_owned());
        }
        Ok(Some(Err(e))) => {
            return Err(format!(
                "failed to read WebSocket init message: {}",
                e.as_report()
            ));
        }
        Err(_) => {
            return Err(format!(
                "timed out waiting for WebSocket init message after {}s",
                WEBSOCKET_INIT_TIMEOUT.as_secs()
            ));
        }
    };

    if let Err(e) =
        authenticate_webhook_payload(headers_jsonb, init_text.as_bytes(), &webhook_source_info)
            .await
    {
        return Err(format!("{}", e.as_report()));
    }

    parse_and_validate_init_request(&init_text, max_clock_skew_ms)?;

    let (ingest_req_tx, ingest_req_rx) = tokio::sync::mpsc::channel(INGEST_DML_REQUEST_BUFFER_SIZE);
    if ingest_req_tx
        .send(IngestDmlRequest {
            request: Some(ingest_dml_request::Request::Init(IngestDmlInitRequest {
                table_id,
                table_version_id,
                request_id,
                row_id_index,
            })),
        })
        .await
        .is_err()
    {
        return Err("failed to enqueue init request for ingest stream".to_owned());
    }

    let mut ingest_resp_stream = compute_client
        .ingest_dml(ReceiverStream::new(ingest_req_rx))
        .await
        .map_err(|e| format!("failed to open ingest stream: {}", e.as_report()))?;

    match ingest_resp_stream.message().await {
        Ok(Some(resp)) => match resp.response {
            Some(ingest_dml_response::Response::Init(_)) => {}
            _ => {
                return Err("unexpected init response from ingest stream".to_owned());
            }
        },
        Ok(None) => return Err("ingest stream closed during init".to_owned()),
        Err(e) => {
            return Err(format!("ingest stream init error: {}", e.as_report()));
        }
    }

    let mut last_seen_dml_batch_id = 0_u64;
    let mut last_forwarded_dml_batch_id = 0_u64;
    let mut last_acked_dml_batch_id = 0_u64;

    loop {
        tokio::select! {
            biased;
            ingest_resp = ingest_resp_stream.message() => {
                match ingest_resp {
                    Ok(Some(resp)) => match resp.response {
                        Some(ingest_dml_response::Response::Ack(ack)) => {
                            if ack.dml_batch_id <= last_acked_dml_batch_id
                                || ack.dml_batch_id > last_forwarded_dml_batch_id
                            {
                                return Err(format!(
                                    "unexpected ack for dml_batch_id {}",
                                    ack.dml_batch_id
                                ));
                            }
                            last_acked_dml_batch_id = ack.dml_batch_id;
                            send_server_message(
                                ws_tx,
                                ServerMessage::Ack {
                                    ack: ack.dml_batch_id,
                                },
                            )
                            .await
                            .map_err(|_| "websocket connection closed while sending ack".to_owned())?;
                        }
                        Some(ingest_dml_response::Response::Init(_)) => {
                            return Err("unexpected extra init response from ingest stream".to_owned());
                        }
                        None => return Err("empty response from ingest stream".to_owned()),
                    },
                    Ok(None) => return Err("ingest stream closed".to_owned()),
                    Err(e) => return Err(format!("ingest stream error: {}", e.as_report())),
                }
            }

            ws_msg = ws_rx.next() => {
                let text = match ws_msg {
                    Some(Ok(Message::Text(text))) => text,
                    Some(Ok(Message::Close(_))) | None => break,
                    Some(Ok(_)) => continue,
                    Some(Err(e)) => {
                        return Err(format!(
                            "failed to read WebSocket message: {}",
                            e.as_report()
                        ));
                    }
                };

                let raw_dml_batch = match serde_json::from_str::<RawDmlBatchRequest>(&text) {
                    Ok(batch) => batch,
                    Err(e) => return Err(format!("malformed payload: {}", e.as_report())),
                };

                last_seen_dml_batch_id = validate_monotonic_dml_batch_id(
                    raw_dml_batch.dml_batch_id,
                    last_seen_dml_batch_id,
                )?;

                if raw_dml_batch.items.is_empty() {
                    send_server_message(
                        ws_tx,
                        ServerMessage::Ack {
                            ack: raw_dml_batch.dml_batch_id,
                        },
                    )
                    .await
                    .map_err(|_| "websocket connection closed while sending ack".to_owned())?;
                    continue;
                }

                let prepared_batch = prepare_dml_batch_payload(
                    &headers,
                    raw_dml_batch,
                    &payload_schema,
                )
                .map_err(|e| format!("failed to prepare DML batch: {e}"))?;

                let PreparedDmlBatch {
                    dml_batch_id,
                    payload,
                } = prepared_batch;

                ingest_req_tx
                    .send(IngestDmlRequest {
                        request: Some(ingest_dml_request::Request::Payload(payload)),
                    })
                    .await
                    .map_err(|_| "ingest stream request channel closed".to_owned())?;
                last_forwarded_dml_batch_id = dml_batch_id;
            }
        }
    }

    Ok(())
}

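/// Parses the init frame, checks that its `type` is [`INIT_MESSAGE_TYPE`], and rejects
/// timestamps outside the configured clock-skew window.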
fn parse_and_validate_init_request(
    text: &str,
    max_clock_skew_ms: u64,
) -> Result<InitRequest, String> {
    let init_req: InitRequest = serde_json::from_str(text)
        .map_err(|e| format!("malformed init message: {}", e.as_report()))?;

    if init_req.msg_type != INIT_MESSAGE_TYPE {
        return Err(format!(
            "invalid init message type '{}', expected '{}'",
            init_req.msg_type, INIT_MESSAGE_TYPE
        ));
    }

    validate_timestamp_skew(init_req.timestamp, max_clock_skew_ms)?;
    Ok(init_req)
}

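/// Rejects init timestamps (epoch milliseconds) that differ from the server clock by
/// more than `max_clock_skew_ms` in either direction.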
fn validate_timestamp_skew(timestamp_ms: i64, max_clock_skew_ms: u64) -> Result<(), String> {
    if timestamp_ms < 0 {
        return Err("timestamp must be a non-negative epoch millisecond".to_owned());
    }

    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO)
        .as_millis() as i128;
    let diff_ms = (now_ms - i128::from(timestamp_ms)).abs();

    if diff_ms > i128::from(max_clock_skew_ms) {
        return Err(format!(
            "timestamp skew {}ms exceeds the allowed {}ms window",
            diff_ms, max_clock_skew_ms
        ));
    }

    Ok(())
}

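/// Ensures `dml_batch_id` strictly increases within a connection and returns the new
/// high-water mark.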
fn validate_monotonic_dml_batch_id(
    dml_batch_id: u64,
    last_seen_dml_batch_id: u64,
) -> Result<u64, String> {
    if dml_batch_id <= last_seen_dml_batch_id {
        return Err(format!(
            "dml_batch_id must increase monotonically: received {} after {}",
            dml_batch_id, last_seen_dml_batch_id
        ));
    }
    Ok(dml_batch_id)
}

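/// Converts one raw batch into a single `StreamChunk`-backed ingest payload.
///
/// For `SingleJsonb` tables each item's `data` becomes one JSONB cell; for `FullSchema`
/// tables it is decoded against the table columns, with decoder behavior optionally
/// tuned by `x-rw-webhook-json-*` request headers. Errors are prefixed with the batch
/// id and the 1-based item index, e.g. `dml_batch_id 11 item 2: ...`.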
fn prepare_dml_batch_payload(
    headers: &HeaderMap,
    raw_dml_batch: RawDmlBatchRequest,
    payload_schema: &PayloadSchema,
) -> Result<PreparedDmlBatch, String> {
    let dml_batch_id = raw_dml_batch.dml_batch_id;
    let raw_dml_reqs = raw_dml_batch.items;

    match payload_schema {
        PayloadSchema::SingleJsonb => {
            let mut chunk_builder = DataChunkBuilder::new(
                vec![DataType::Jsonb],
                raw_dml_reqs.len().saturating_add(1).max(1),
            );
            let mut ops = Vec::with_capacity(raw_dml_reqs.len());

            for (index, raw_dml_req) in raw_dml_reqs.into_iter().enumerate() {
                let item_index = index + 1;
                let dml_req = DmlRequest::try_from(raw_dml_req)
                    .map_err(|e| format!("dml_batch_id {dml_batch_id} item {item_index}: {e}"))?;

                let row = Value::from_text(dml_req.data.get().as_bytes())
                    .map(|json_value| OwnedRow::new(vec![Some(JsonbVal::from(json_value).into())]))
                    .map_err(|e| {
                        format!(
                            "dml_batch_id {dml_batch_id} item {item_index}: failed to parse body: {}",
                            e.as_report()
                        )
                    })?;

                let output = chunk_builder.append_one_row(row);
                debug_assert!(output.is_none());
                ops.push(match dml_req.op {
                    DmlOp::Upsert => Op::Insert,
                    DmlOp::Delete => Op::Delete,
                });
            }

            let data_chunk = chunk_builder
                .consume_all()
                .expect("buffered rows should produce a chunk");
            let chunk = StreamChunk::from_parts(ops, data_chunk);
            let payload = IngestDmlPayloadRequest {
                dml_batch_id,
                chunk: Some(chunk.to_protobuf()),
            };

            Ok(PreparedDmlBatch {
                dml_batch_id,
                payload,
            })
        }
        PayloadSchema::FullSchema { columns } => {
            let mut access_builder =
                build_json_access_builder(headers).map_err(|e| format!("{}", e.as_report()))?;
            let mut chunk_builder = DataChunkBuilder::new(
                columns
                    .iter()
                    .map(|column| column.data_type.clone())
                    .collect(),
                raw_dml_reqs.len().saturating_add(1).max(1),
            );
            let mut ops = Vec::with_capacity(raw_dml_reqs.len());

            for (index, raw_dml_req) in raw_dml_reqs.into_iter().enumerate() {
                let item_index = index + 1;
                let dml_req = DmlRequest::try_from(raw_dml_req)
                    .map_err(|e| format!("dml_batch_id {dml_batch_id} item {item_index}: {e}"))?;

                let row = owned_row_from_payload_row(
                    &mut access_builder,
                    columns,
                    dml_req.data.get().as_bytes(),
                )
                .map_err(|e| {
                    format!(
                        "dml_batch_id {dml_batch_id} item {item_index}: {}",
                        e.as_report()
                    )
                })?;

                let output = chunk_builder.append_one_row(row);
                debug_assert!(output.is_none());
                ops.push(match dml_req.op {
                    DmlOp::Upsert => Op::Insert,
                    DmlOp::Delete => Op::Delete,
                });
            }

            let data_chunk = chunk_builder
                .consume_all()
                .expect("buffered rows should produce a chunk");
            let chunk = StreamChunk::from_parts(ops, data_chunk);
            let payload = IngestDmlPayloadRequest {
                dml_batch_id,
                chunk: Some(chunk.to_protobuf()),
            };

            Ok(PreparedDmlBatch {
                dml_batch_id,
                payload,
            })
        }
    }
}

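/// Serializes a [`ServerMessage`] to JSON and sends it as a single text frame.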
async fn send_server_message(ws_tx: &mut WsTx, msg: ServerMessage) -> Result<(), String> {
    let text = serde_json::to_string(&msg).map_err(|e| format!("{}", e.as_report()))?;
    ws_tx
        .send(Message::Text(text.into()))
        .await
        .map_err(|e| format!("{}", e.as_report()))
}

async fn send_fatal(ws_tx: &mut WsTx, fatal: String) -> Result<(), String> {
    send_server_message(ws_tx, ServerMessage::Fatal { fatal }).await
}

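/// Builds the router exposing the WebSocket ingest endpoint at
/// `/{database}/{schema}/{table}`, with the service handle injected via an
/// `Extension` layer. A minimal mounting sketch (how the router is actually served
/// is up to the caller):
///
/// ```ignore
/// let app = build_router(Arc::new(IngestService::new()));
/// // e.g. axum::serve(listener, app).await?;
/// ```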
pub fn build_router(svc: ServiceRef) -> Router {
    Router::new()
        .route("/{database}/{schema}/{table}", get(ws_handler))
        .layer(
            ServiceBuilder::new()
                .layer(AddExtensionLayer::new(svc))
                .into_inner(),
        )
}

#[cfg(test)]
mod tests {
    use axum::http::{HeaderMap, HeaderValue};
    use risingwave_common::row::Row;
    use risingwave_common::types::{DataType, ScalarImpl, ToOwnedDatum};

    use super::*;
    use crate::webhook::WebhookTableColumnDesc;

    fn raw_json(text: &str) -> Box<RawValue> {
        serde_json::from_str(text).unwrap()
    }

    fn raw_req(op: Option<&str>, data: &str) -> RawDmlRequest {
        RawDmlRequest {
            op: op.map(str::to_owned),
            data: raw_json(data),
        }
    }

    fn raw_batch(dml_batch_id: u64, items: Vec<RawDmlRequest>) -> RawDmlBatchRequest {
        RawDmlBatchRequest {
            dml_batch_id,
            items,
        }
    }

    fn test_columns(columns: &[(&str, DataType, bool)]) -> Vec<WebhookTableColumnDesc> {
        columns
            .iter()
            .map(|(name, data_type, is_pk)| WebhookTableColumnDesc {
                name: (*name).to_owned(),
                data_type: data_type.clone(),
                is_pk: *is_pk,
            })
            .collect()
    }

    #[test]
    fn test_validate_monotonic_dml_batch_id() {
        assert_eq!(validate_monotonic_dml_batch_id(5, 1).unwrap(), 5);
    }

    #[test]
    fn test_validate_monotonic_dml_batch_id_rejects_non_monotonic_sequence() {
        let err = validate_monotonic_dml_batch_id(3, 3).unwrap_err();
        assert!(err.contains("dml_batch_id must increase monotonically"));
    }

    #[test]
    fn test_parse_and_validate_init_request() {
        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_millis() as i64;
        let init = format!(r#"{{"type":"init","timestamp":{now_ms}}}"#);
        assert_eq!(
            parse_and_validate_init_request(&init, 300_000)
                .unwrap()
                .timestamp,
            now_ms
        );
    }

    #[test]
    fn test_parse_and_validate_init_request_rejects_stale_timestamp() {
        let init = r#"{"type":"init","timestamp":0}"#;
        let err = parse_and_validate_init_request(init, 1).unwrap_err();
        assert!(err.contains("timestamp skew"));
    }

    #[test]
    fn test_dml_request_requires_op() {
        let err = DmlRequest::try_from(raw_req(None, r#"{"id":1}"#)).unwrap_err();
        assert_eq!(err, "missing op, expected upsert/delete");
    }

    #[test]
    fn test_empty_batch_is_valid_json_protocol_input() {
        let raw_dml_batch = raw_batch(42, vec![]);
        assert_eq!(
            validate_monotonic_dml_batch_id(raw_dml_batch.dml_batch_id, 41).unwrap(),
            42
        );
        assert!(raw_dml_batch.items.is_empty());
    }

    #[test]
    fn test_prepare_dml_batch_payload_builds_single_chunk_for_batch() {
        let raw_dml_batch = raw_batch(
            10,
            vec![
                raw_req(
                    Some("upsert"),
                    r#"{"id":1,"price":"19.99","created_at":"2026-04-15 10:00:00","name":"alice"}"#,
                ),
                raw_req(
                    Some("delete"),
                    r#"{"id":2,"price":"29.99","created_at":"2026-04-16 10:00:00","name":"bob"}"#,
                ),
            ],
        );
        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("price", DataType::Decimal, false),
                ("created_at", DataType::Timestamp, false),
                ("name", DataType::Varchar, false),
            ]),
        };

        let prepared_batch =
            prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema).unwrap();

        assert_eq!(prepared_batch.dml_batch_id, 10);

        let payload = prepared_batch.payload;
        assert_eq!(payload.dml_batch_id, 10);
        let chunk = StreamChunk::from_protobuf(payload.chunk.as_ref().unwrap()).unwrap();

        assert_eq!(chunk.ops(), &[Op::Insert, Op::Delete]);
        let mut rows = chunk.rows();
        let row = rows.next().unwrap().1;
        assert_eq!(row.datum_at(0).to_owned_datum(), Some(ScalarImpl::Int32(1)));
        assert!(matches!(
            row.datum_at(1).to_owned_datum(),
            Some(ScalarImpl::Decimal(_))
        ));
        assert!(matches!(
            row.datum_at(2).to_owned_datum(),
            Some(ScalarImpl::Timestamp(_))
        ));
        assert_eq!(
            row.datum_at(3).to_owned_datum(),
            Some(ScalarImpl::Utf8("alice".into()))
        );

        let row = rows.next().unwrap().1;
        assert_eq!(row.datum_at(0).to_owned_datum(), Some(ScalarImpl::Int32(2)));
    }

    #[test]
    fn test_prepare_dml_batch_payload_returns_first_item_error() {
        let raw_dml_batch = raw_batch(
            11,
            vec![
                raw_req(Some("delete"), r#"{"id":1,"name":"alice"}"#),
                raw_req(Some("upsert"), r#"{"name":"bob"}"#),
            ],
        );
        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("name", DataType::Varchar, false),
            ]),
        };

        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
            .unwrap_err();

        assert!(err.contains("dml_batch_id 11 item 2"));
        assert!(err.contains("failed to decode webhook JSON payload"));
    }

    #[test]
    fn test_prepare_dml_batch_payload_allows_delete_with_only_pk() {
        let raw_dml_batch = raw_batch(
            12,
            vec![
                raw_req(Some("upsert"), r#"{"id":7,"name":"alice"}"#),
                raw_req(Some("delete"), r#"{"id":7}"#),
            ],
        );
        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("name", DataType::Varchar, false),
            ]),
        };

        let prepared_batch =
            prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema).unwrap();
        let chunk =
            StreamChunk::from_protobuf(prepared_batch.payload.chunk.as_ref().unwrap()).unwrap();

        assert_eq!(prepared_batch.dml_batch_id, 12);
        assert_eq!(chunk.ops(), &[Op::Insert, Op::Delete]);
        let mut rows = chunk.rows();
        assert_eq!(
            rows.next().unwrap().1.datum_at(1).to_owned_datum(),
            Some(ScalarImpl::Utf8("alice".into()))
        );
        assert_eq!(rows.next().unwrap().1.datum_at(1).to_owned_datum(), None);
    }

    #[test]
    fn test_prepare_dml_batch_payload_rejects_incomplete_composite_pk_insert() {
        let raw_dml_batch = raw_batch(
            51,
            vec![raw_req(Some("upsert"), r#"{"id":1,"name":"alice"}"#)],
        );
        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("tenant_id", DataType::Int32, true),
                ("id", DataType::Int32, true),
                ("name", DataType::Varchar, false),
            ]),
        };

        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
            .unwrap_err();

        assert!(err.contains("dml_batch_id 51 item 1"));
        assert!(err.contains("failed to decode webhook JSON payload"));
    }

    #[test]
    fn test_prepare_dml_batch_payload_rejects_incomplete_composite_pk_delete() {
        let raw_dml_batch = raw_batch(52, vec![raw_req(Some("delete"), r#"{"tenant_id":1}"#)]);
        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("tenant_id", DataType::Int32, true),
                ("id", DataType::Int32, true),
                ("name", DataType::Varchar, false),
            ]),
        };

        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
            .unwrap_err();

        assert!(err.contains("dml_batch_id 52 item 1"));
        assert!(err.contains("failed to decode webhook JSON payload"));
    }

    #[test]
    fn test_prepare_dml_batch_payload_applies_supported_decoder_headers() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-rw-webhook-json-timestamp-handling-mode",
            HeaderValue::from_static("milli"),
        );
        headers.insert(
            "x-rw-webhook-json-time-handling-mode",
            HeaderValue::from_static("milli"),
        );
        headers.insert(
            "x-rw-webhook-json-bigint-unsigned-handling-mode",
            HeaderValue::from_static("precise"),
        );

        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("event_time", DataType::Timestamp, false),
            ]),
        };
        let prepared_timestamp_batch = prepare_dml_batch_payload(
            &headers,
            raw_batch(
                13,
                vec![raw_req(
                    Some("upsert"),
                    r#"{"id":1,"event_time":1712800800123}"#,
                )],
            ),
            &payload_schema,
        )
        .unwrap();
        let timestamp_chunk =
            StreamChunk::from_protobuf(prepared_timestamp_batch.payload.chunk.as_ref().unwrap())
                .unwrap();
        assert!(matches!(
            timestamp_chunk
                .rows()
                .next()
                .unwrap()
                .1
                .datum_at(1)
                .to_owned_datum(),
            Some(ScalarImpl::Timestamp(_))
        ));

        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("event_time", DataType::Time, false),
            ]),
        };
        let prepared_time_batch = prepare_dml_batch_payload(
            &headers,
            raw_batch(
                14,
                vec![raw_req(Some("upsert"), r#"{"id":2,"event_time":3723123}"#)],
            ),
            &payload_schema,
        )
        .unwrap();
        let time_chunk =
            StreamChunk::from_protobuf(prepared_time_batch.payload.chunk.as_ref().unwrap())
                .unwrap();
        assert!(matches!(
            time_chunk
                .rows()
                .next()
                .unwrap()
                .1
                .datum_at(1)
                .to_owned_datum(),
            Some(ScalarImpl::Time(_))
        ));

        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("amount", DataType::Decimal, false),
            ]),
        };
        let prepared_decimal_batch = prepare_dml_batch_payload(
            &headers,
            raw_batch(
                15,
                vec![raw_req(Some("upsert"), r#"{"id":3,"amount":"AeJA"}"#)],
            ),
            &payload_schema,
        )
        .unwrap();
        let decimal_chunk =
            StreamChunk::from_protobuf(prepared_decimal_batch.payload.chunk.as_ref().unwrap())
                .unwrap();
        assert!(matches!(
            decimal_chunk
                .rows()
                .next()
                .unwrap()
                .1
                .datum_at(1)
                .to_owned_datum(),
            Some(ScalarImpl::Decimal(_))
        ));
    }

    #[test]
    fn test_prepare_dml_batch_payload_rejects_invalid_decoder_headers() {
        for (header, value, expected) in [
            (
                "x-rw-webhook-json-timestamp-handling-mode",
                "invalid",
                "unrecognized `x-rw-webhook-json-timestamp-handling-mode` value",
            ),
            (
                "x-rw-webhook-json-timestamptz-handling-mode",
                "invalid",
                "invalid webhook JSON decoder option",
            ),
            (
                "x-rw-webhook-json-time-handling-mode",
                "invalid",
                "unrecognized `x-rw-webhook-json-time-handling-mode` value",
            ),
            (
                "x-rw-webhook-json-bigint-unsigned-handling-mode",
                "invalid",
                "unrecognized `x-rw-webhook-json-bigint-unsigned-handling-mode` value",
            ),
            (
                "x-rw-webhook-json-handle-toast-columns",
                "invalid",
                "unrecognized `x-rw-webhook-json-handle-toast-columns` value",
            ),
        ] {
            let mut headers = HeaderMap::new();
            headers.insert(header, HeaderValue::from_static(value));
            let raw_dml_batch = raw_batch(
                21,
                vec![raw_req(Some("upsert"), r#"{"id":1,"name":"alice"}"#)],
            );
            let payload_schema = PayloadSchema::FullSchema {
                columns: test_columns(&[
                    ("id", DataType::Int32, true),
                    ("name", DataType::Varchar, false),
                ]),
            };

            let err =
                prepare_dml_batch_payload(&headers, raw_dml_batch, &payload_schema).unwrap_err();
            assert!(err.contains(expected), "{header}");
        }
    }

    #[test]
    fn test_prepare_dml_batch_payload_returns_first_type_error_with_item_index() {
        let raw_dml_batch = raw_batch(
            31,
            vec![
                raw_req(Some("upsert"), r#"{"id":1,"name":"alice"}"#),
                raw_req(Some("upsert"), r#"{"id":"not-an-int","name":"bob"}"#),
            ],
        );
        let payload_schema = PayloadSchema::FullSchema {
            columns: test_columns(&[
                ("id", DataType::Int32, true),
                ("name", DataType::Varchar, false),
            ]),
        };

        let err = prepare_dml_batch_payload(&HeaderMap::new(), raw_dml_batch, &payload_schema)
            .unwrap_err();

        assert!(err.contains("dml_batch_id 31 item 2"));
        assert!(err.contains("failed to decode webhook JSON payload"));
    }

    #[test]
    fn test_prepare_dml_batch_payload_single_jsonb_accepts_scalar_json() {
        let raw_dml_batch = raw_batch(41, vec![raw_req(Some("upsert"), "123")]);

        let prepared_batch = prepare_dml_batch_payload(
            &HeaderMap::new(),
            raw_dml_batch,
            &PayloadSchema::SingleJsonb,
        )
        .unwrap();
        let payload = prepared_batch.payload;
        let chunk = StreamChunk::from_protobuf(payload.chunk.as_ref().unwrap()).unwrap();

        assert_eq!(prepared_batch.dml_batch_id, 41);
        assert_eq!(payload.dml_batch_id, 41);
        assert_eq!(chunk.ops(), &[Op::Insert]);
        assert!(matches!(
            chunk.rows().next().unwrap().1.datum_at(0).to_owned_datum(),
            Some(ScalarImpl::Jsonb(_))
        ));
    }
}