1use std::collections::{BTreeMap, HashMap, HashSet};
16
17use anyhow::{Context, anyhow};
18use itertools::Itertools;
19use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
20use risingwave_common::array::{Op, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, ScalarRefImpl};
24use serde::{Deserialize, Serialize};
25use serde_json::{Map, Value};
26use serde_with::{DisplayFromStr, serde_as};
27use thiserror_ext::AsReport;
28use with_options::WithOptions;
29
30use crate::enforce_secret::EnforceSecret;
31use crate::sink::encoder::{JsonEncoder, RowEncoder};
32use crate::sink::log_store::DeliveryFutureManagerAddFuture;
33use crate::sink::writer::{
34 AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, AsyncTruncateSinkWriterExt,
35};
36use crate::sink::{Result, Sink, SinkError, SinkParam, SinkWriterParam};
37
/// Name of the Turbopuffer sink connector; used as `Sink::SINK_NAME` below.
pub const TURBOPUFFER_SINK: &str = "turbopuffer";
39
40#[serde_as]
41#[derive(Clone, Debug, Deserialize, WithOptions)]
42pub struct TurbopufferConfig {
43 pub base_url: String,
44 pub namespace: Option<String>,
45 pub namespace_column: Option<String>,
46 pub api_key: String,
47 pub distance_metric: Option<String>,
48 #[serde_as(as = "Option<DisplayFromStr>")]
49 pub disable_backpressure: Option<bool>,
50 pub full_text_search_columns: Option<String>,
51 pub filterable_columns: Option<String>,
52 pub r#type: String, }
54
55impl EnforceSecret for TurbopufferConfig {
56 const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
57 "api_key",
58 };
59}
60
61impl TurbopufferConfig {
62 fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
63 serde_json::from_value::<TurbopufferConfig>(
64 serde_json::to_value(values).expect("serialize sink properties"),
65 )
66 .map_err(|e| SinkError::Config(anyhow!(e)))
67 }
68}
69
/// Where the Turbopuffer namespace of each row comes from.
#[derive(Clone, Debug)]
enum TurbopufferNamespace {
    /// All rows go to this fixed namespace (the `namespace` option).
    Static(String),
    /// Each row's namespace is read from the varchar column at `index`
    /// (the `namespace_column` option).
    Dynamic { index: usize },
}
75
76#[derive(Clone, Debug)]
77pub struct TurbopufferSink {
78 config: TurbopufferConfig,
79 schema: Schema,
80 is_append_only: bool,
81 pk_index: usize,
82 namespace: TurbopufferNamespace,
83 attribute_indices: Vec<usize>,
84 generated_schema: Value,
85}
86
87impl EnforceSecret for TurbopufferSink {
88 fn enforce_secret<'a>(
89 prop_iter: impl Iterator<Item = &'a str>,
90 ) -> crate::error::ConnectorResult<()> {
91 for prop in prop_iter {
92 TurbopufferConfig::enforce_one(prop)?;
93 }
94 Ok(())
95 }
96}
97
98impl TryFrom<SinkParam> for TurbopufferSink {
99 type Error = SinkError;
100
101 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
102 let schema = param.schema();
103 let is_append_only = param.sink_type.is_append_only();
104 let pk_indices = param.downstream_pk_or_empty();
105 let [pk_index] = pk_indices.as_slice() else {
106 return Err(SinkError::Config(anyhow!(
107 "Turbopuffer sink requires exactly one primary_key column"
108 )));
109 };
110 let pk_index = *pk_index;
111 match schema[pk_index].data_type() {
112 DataType::Int16
113 | DataType::Int32
114 | DataType::Int64
115 | DataType::Serial
116 | DataType::Varchar => {}
117 data_type => {
118 return Err(SinkError::Config(anyhow!(
119 "Turbopuffer document id column must be an integer or varchar, got {:?}",
120 data_type
121 )));
122 }
123 };
124 let config = TurbopufferConfig::from_btreemap(param.properties)?;
125
126 let namespace = match (&config.namespace, &config.namespace_column) {
127 (Some(namespace), None) => {
128 validate_namespace(namespace)?;
129 TurbopufferNamespace::Static(namespace.clone())
130 }
131 (None, Some(namespace_column)) => {
132 let index = schema
133 .fields()
134 .iter()
135 .position(|field| field.name == *namespace_column)
136 .ok_or_else(|| {
137 SinkError::Config(anyhow!(
138 "Turbopuffer namespace_column '{}' not found in sink schema",
139 namespace_column
140 ))
141 })?;
142 if schema[index].data_type != DataType::Varchar {
143 return Err(SinkError::Config(anyhow!(
144 "Turbopuffer namespace_column must be varchar, got {:?}",
145 schema[index].data_type
146 )));
147 }
148 TurbopufferNamespace::Dynamic { index }
149 }
150 (Some(_), Some(_)) => {
151 return Err(SinkError::Config(anyhow!(
152 "Turbopuffer sink requires only one of namespace or namespace_column"
153 )));
154 }
155 (None, None) => {
156 return Err(SinkError::Config(anyhow!(
157 "Turbopuffer sink requires either namespace or namespace_column"
158 )));
159 }
160 };
161
162 let excluded_indices = match &namespace {
165 TurbopufferNamespace::Static(_) => HashSet::from([pk_index]),
166 TurbopufferNamespace::Dynamic { index } => HashSet::from([pk_index, *index]),
167 };
168 let attribute_indices = (0..schema.len())
169 .filter(|idx| !excluded_indices.contains(idx))
170 .collect_vec();
171 for index in &attribute_indices {
172 if schema[*index].name == "id" {
173 return Err(SinkError::Config(anyhow!(
174 "Turbopuffer attribute column must not be named id"
175 )));
176 }
177 }
178 let full_text_search_columns = parse_column_selection(
179 config.full_text_search_columns.as_deref(),
180 &schema,
181 &attribute_indices,
182 )?;
183 let filterable_columns = parse_column_selection(
184 config.filterable_columns.as_deref(),
185 &schema,
186 &attribute_indices,
187 )?;
188 let has_vector = attribute_indices
189 .iter()
190 .any(|idx| matches!(schema[*idx].data_type, DataType::Vector(_)));
191 if has_vector && config.distance_metric.is_none() {
192 return Err(SinkError::Config(anyhow!(
193 "Turbopuffer sink requires distance_metric when sink schema contains vector columns"
194 )));
195 }
196 let generated_schema = build_turbopuffer_schema(
200 &schema,
201 &attribute_indices,
202 &full_text_search_columns,
203 &filterable_columns,
204 )?;
205
206 Ok(Self {
207 config,
208 schema,
209 is_append_only,
210 pk_index,
211 namespace,
212 attribute_indices,
213 generated_schema,
214 })
215 }
216}
217
218impl Sink for TurbopufferSink {
219 type LogSinker = AsyncTruncateLogSinkerOf<TurbopufferSinkWriter>;
220
221 const SINK_NAME: &'static str = TURBOPUFFER_SINK;
222
223 async fn validate(&self) -> Result<()> {
224 Ok(())
225 }
226
227 async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
228 Ok(TurbopufferSinkWriter::new(
229 self.config.clone(),
230 self.schema.clone(),
231 self.is_append_only,
232 self.pk_index,
233 self.namespace.clone(),
234 self.attribute_indices.clone(),
235 self.generated_schema.clone(),
236 )?
237 .into_log_sinker(usize::MAX))
238 }
239}
240
241pub struct TurbopufferSinkWriter {
242 client: reqwest::Client,
243 base_url: String,
244 distance_metric: Option<String>,
245 disable_backpressure: Option<bool>,
246 schema: Value,
247 is_append_only: bool,
248 pk_index: usize,
249 namespace: TurbopufferNamespace,
250 row_encoder: JsonEncoder,
251}
252
253impl TurbopufferSinkWriter {
254 fn new(
255 config: TurbopufferConfig,
256 schema: Schema,
257 is_append_only: bool,
258 pk_index: usize,
259 namespace: TurbopufferNamespace,
260 attribute_indices: Vec<usize>,
261 generated_schema: Value,
262 ) -> Result<Self> {
263 let mut header_map = HeaderMap::new();
264 header_map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
265 let authorization = format!("Bearer {}", config.api_key);
266 header_map.insert(
267 AUTHORIZATION,
268 authorization
269 .parse()
270 .context("invalid turbopuffer api_key")
271 .map_err(SinkError::Config)?,
272 );
273 let client = reqwest::Client::builder()
274 .default_headers(header_map)
275 .build()
276 .context("failed to build turbopuffer HTTP client")
277 .map_err(SinkError::Http)?;
278 let base_url = config
279 .base_url
280 .parse::<reqwest::Url>()
281 .context("invalid turbopuffer base_url")
282 .map_err(SinkError::Config)?
283 .to_string()
284 .trim_end_matches('/')
285 .to_owned();
286 let row_encoder = JsonEncoder::new_with_turbopuffer(schema, Some(attribute_indices));
287 Ok(Self {
288 client,
289 base_url,
290 distance_metric: config.distance_metric,
291 disable_backpressure: config.disable_backpressure,
292 schema: generated_schema,
293 is_append_only,
294 pk_index,
295 namespace,
296 row_encoder,
297 })
298 }
299
300 fn url_for_row(&self, row: &impl Row) -> Result<String> {
301 match &self.namespace {
302 TurbopufferNamespace::Static(namespace) => Ok(format!(
303 "{}/v2/namespaces/{}",
304 self.base_url,
305 namespace.as_str()
306 )),
307 TurbopufferNamespace::Dynamic { index } => {
308 let namespace = match row.datum_at(*index) {
309 Some(ScalarRefImpl::Utf8(namespace)) => namespace,
310 None => {
311 return Err(SinkError::Http(anyhow!(
312 "Turbopuffer namespace_column cannot be null"
313 )));
314 }
315 Some(_) => {
316 return Err(SinkError::Http(anyhow!(
317 "unexpected namespace_column type, expected varchar"
318 )));
319 }
320 };
321 validate_namespace(namespace)?;
322 Ok(format!("{}/v2/namespaces/{}", self.base_url, namespace))
323 }
324 }
325 }
326
327 fn id_for_row(&self, row: &impl Row) -> Result<DocumentId> {
330 let datum = row.datum_at(self.pk_index).ok_or_else(|| {
331 SinkError::Http(anyhow!("Turbopuffer document id column cannot be null"))
332 })?;
333 match datum {
334 ScalarRefImpl::Int16(value) => Ok(document_id_from_i64(value as i64)),
335 ScalarRefImpl::Int32(value) => Ok(document_id_from_i64(value as i64)),
336 ScalarRefImpl::Int64(value) => Ok(document_id_from_i64(value)),
337 ScalarRefImpl::Serial(value) => Ok(document_id_from_i64(value.into_inner())),
338 ScalarRefImpl::Utf8(value) => {
339 if value.len() > 64 {
340 return Err(SinkError::Http(anyhow!(
341 "Turbopuffer string document id exceeds 64 bytes"
342 )));
343 }
344 Ok(DocumentId::String(value.to_owned()))
345 }
346 _ => Err(SinkError::Http(anyhow!(
347 "Turbopuffer document id column must be an integer or varchar"
348 ))),
349 }
350 }
351
352 fn upsert_row(&self, row: &impl Row, id: DocumentId) -> Result<Map<String, Value>> {
353 let mut value = self.row_encoder.encode(row)?;
354 value.insert(
355 "id".to_owned(),
356 serde_json::to_value(id).expect("serialize document id"),
357 );
358 Ok(value)
359 }
360
361 fn request_body(
362 &self,
363 upsert_rows: Vec<Map<String, Value>>,
364 deletes: Vec<DocumentId>,
365 ) -> Value {
366 let mut body = Map::new();
367 if let Some(distance_metric) = &self.distance_metric {
368 body.insert(
369 "distance_metric".to_owned(),
370 Value::String(distance_metric.clone()),
371 );
372 }
373 if !upsert_rows.is_empty() {
374 if let Some(disable_backpressure) = self.disable_backpressure {
375 body.insert(
376 "disable_backpressure".to_owned(),
377 Value::Bool(disable_backpressure),
378 );
379 }
380 body.insert("schema".to_owned(), self.schema.clone());
381 body.insert(
382 "upsert_rows".to_owned(),
383 Value::Array(upsert_rows.into_iter().map(Value::Object).collect()),
384 );
385 }
386 if !deletes.is_empty() {
387 body.insert(
388 "deletes".to_owned(),
389 Value::Array(
390 deletes
391 .into_iter()
392 .map(|id| serde_json::to_value(id).expect("serialize document id"))
393 .collect(),
394 ),
395 );
396 }
397 Value::Object(body)
398 }
399}
400
401impl AsyncTruncateSinkWriter for TurbopufferSinkWriter {
402 async fn write_chunk<'a>(
403 &'a mut self,
404 chunk: StreamChunk,
405 _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
406 ) -> Result<()> {
407 if self.is_append_only {
408 let mut batches = BTreeMap::<String, Vec<Map<String, Value>>>::new();
409 for (op, row) in chunk.rows() {
410 match op {
411 Op::Insert | Op::UpdateInsert => {
412 let id = match self.id_for_row(&row) {
413 Ok(id) => id,
414 Err(err) => {
415 tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid document id");
416 continue;
417 }
418 };
419 let upsert_row = match self.upsert_row(&row, id) {
420 Ok(row) => row,
421 Err(err) => {
422 tracing::warn!(error = %err.as_report(), "skip turbopuffer row failed to encode upsert payload");
423 continue;
424 }
425 };
426 let url = match self.url_for_row(&row) {
427 Ok(url) => url,
428 Err(err) => {
429 tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid namespace");
430 continue;
431 }
432 };
433 batches.entry(url).or_default().push(upsert_row);
434 }
435 Op::Delete | Op::UpdateDelete => {
436 return Err(SinkError::Http(anyhow!(
437 "`Delete` or `UpdateDelete` operation is not supported in append-only turbopuffer sink"
438 )));
439 }
440 }
441 }
442
443 for (url, upsert_rows) in batches {
444 if upsert_rows.is_empty() {
445 continue;
446 }
447 let resp = self
448 .client
449 .post(url)
450 .json(&self.request_body(upsert_rows, Vec::new()))
451 .send()
452 .await
453 .context("turbopuffer write request failed")
454 .map_err(SinkError::Http)?;
455
456 if !resp.status().is_success() {
457 let status = resp.status();
458 let body = resp.text().await.unwrap_or_default();
459 return Err(SinkError::Http(anyhow!(
460 "Turbopuffer sink received non-success response: {} {}",
461 status,
462 body
463 )));
464 }
465 }
466 } else {
467 let mut batches = BTreeMap::<String, HashMap<DocumentId, CompactedOp>>::new();
468 for (op, row) in chunk.rows() {
469 let id = match self.id_for_row(&row) {
470 Ok(id) => id,
471 Err(err) => {
472 tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid document id");
473 continue;
474 }
475 };
476 let url = match self.url_for_row(&row) {
477 Ok(url) => url,
478 Err(err) => {
479 tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid namespace");
480 continue;
481 }
482 };
483 match op {
484 Op::Insert | Op::UpdateInsert => {
485 let upsert_row = match self.upsert_row(&row, id.clone()) {
486 Ok(row) => row,
487 Err(err) => {
488 tracing::warn!(error = %err.as_report(), "skip turbopuffer row failed to encode upsert payload");
489 continue;
490 }
491 };
492 batches
493 .entry(url)
494 .or_default()
495 .insert(id, CompactedOp::Upsert(upsert_row));
496 }
497 Op::Delete | Op::UpdateDelete => {
498 batches
499 .entry(url)
500 .or_default()
501 .insert(id, CompactedOp::Delete);
502 }
503 }
504 }
505
506 for (url, batch) in batches {
507 let mut upsert_rows = Vec::new();
508 let mut deletes = Vec::new();
509 for (id, op) in batch {
510 match op {
511 CompactedOp::Upsert(row) => upsert_rows.push(row),
512 CompactedOp::Delete => deletes.push(id),
513 }
514 }
515 if upsert_rows.is_empty() && deletes.is_empty() {
516 continue;
517 }
518 let resp = self
519 .client
520 .post(url)
521 .json(&self.request_body(upsert_rows, deletes))
522 .send()
523 .await
524 .context("turbopuffer write request failed")
525 .map_err(SinkError::Http)?;
526
527 if !resp.status().is_success() {
528 let status = resp.status();
529 let body = resp.text().await.unwrap_or_default();
530 return Err(SinkError::Http(anyhow!(
531 "Turbopuffer sink received non-success response: {} {}",
532 status,
533 body
534 )));
535 }
536 }
537 }
538
539 Ok(())
540 }
541}
542
543#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize)]
544#[serde(untagged)]
545enum DocumentId {
546 U64(u64),
547 String(String),
548}
549
550#[derive(Debug)]
551enum CompactedOp {
552 Upsert(Map<String, Value>),
553 Delete,
554}
555
556fn document_id_from_i64(value: i64) -> DocumentId {
557 if value < 0 {
558 tracing::warn!(
559 value,
560 "cast negative turbopuffer integer document id to unsigned integer"
561 );
562 }
563 DocumentId::U64(value as u64)
564}
565
566fn validate_namespace(namespace: &str) -> Result<()> {
567 if namespace.is_empty() || namespace.len() > 128 {
568 return Err(SinkError::Config(anyhow!(
569 "Turbopuffer namespace must be 1 to 128 bytes"
570 )));
571 }
572 if !namespace
573 .bytes()
574 .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.'))
575 {
576 return Err(SinkError::Config(anyhow!(
577 "Turbopuffer namespace must match [A-Za-z0-9-_.]{{1,128}}"
578 )));
579 }
580 Ok(())
581}
582
583fn parse_column_selection(
584 value: Option<&str>,
585 schema: &Schema,
586 attribute_indices: &[usize],
587) -> Result<HashSet<String>> {
588 let attribute_names = attribute_indices
589 .iter()
590 .map(|index| schema[*index].name.as_str())
591 .collect::<HashSet<_>>();
592 let columns: HashSet<String> = match value {
593 Some(value) if value.trim() == "*" => {
594 return Ok(attribute_names
595 .into_iter()
596 .map(str::to_owned)
597 .collect::<HashSet<_>>());
598 }
599 Some(value) => value
600 .split(',')
601 .map(str::trim)
602 .filter(|name| !name.is_empty())
603 .map(str::to_owned)
604 .collect(),
605 None => return Ok(HashSet::new()),
606 };
607 for column in &columns {
608 if !attribute_names.contains(column.as_str()) {
609 return Err(SinkError::Config(anyhow!(
610 "Turbopuffer schema option references unknown attribute column '{}'",
611 column
612 )));
613 }
614 }
615 Ok(columns)
616}
617
618fn build_turbopuffer_schema(
619 schema: &Schema,
620 attribute_indices: &[usize],
621 full_text_search_columns: &HashSet<String>,
622 filterable_columns: &HashSet<String>,
623) -> Result<Value> {
624 let mut result = Map::new();
625 for index in attribute_indices {
626 let field = &schema[*index];
627 let mut config = Map::new();
628 let data_type = field.data_type();
629 let turbopuffer_type = turbopuffer_type(&data_type)?;
630 let is_vector = matches!(data_type, DataType::Vector(_));
631 let is_full_text_search = full_text_search_columns.contains(&field.name);
632 if is_full_text_search && !supports_full_text_search(&data_type) {
633 return Err(SinkError::Config(anyhow!(
634 "Turbopuffer full_text_search column '{}' must be string or []string",
635 field.name
636 )));
637 }
638 config.insert("type".to_owned(), Value::String(turbopuffer_type));
639 if filterable_columns.contains(&field.name) {
640 config.insert("filterable".to_owned(), Value::Bool(true));
641 }
642 if is_full_text_search {
643 config.insert("full_text_search".to_owned(), Value::Bool(true));
644 }
645 if is_vector {
646 config.insert("ann".to_owned(), Value::Bool(true));
647 }
648 result.insert(field.name.clone(), Value::Object(config));
649 }
650 Ok(Value::Object(result))
651}
652
653fn supports_full_text_search(data_type: &DataType) -> bool {
654 match data_type {
655 DataType::Varchar => true,
656 DataType::List(list_type) => matches!(list_type.elem(), DataType::Varchar),
657 _ => false,
658 }
659}
660
661fn turbopuffer_type(data_type: &DataType) -> Result<String> {
686 match data_type {
687 DataType::Boolean => Ok("bool".to_owned()),
688 DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => {
689 Ok("int".to_owned())
690 }
691 DataType::Float32 | DataType::Float64 | DataType::Decimal => Ok("float".to_owned()),
692 DataType::Varchar => Ok("string".to_owned()),
693 DataType::Date | DataType::Timestamp => Ok("datetime".to_owned()),
694 DataType::List(list_type) => match list_type.elem() {
695 DataType::Boolean => Ok("[]bool".to_owned()),
696 DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => {
697 Ok("[]int".to_owned())
698 }
699 DataType::Float32 | DataType::Float64 | DataType::Decimal => Ok("[]float".to_owned()),
700 DataType::Varchar => Ok("[]string".to_owned()),
701 DataType::Date | DataType::Timestamp => Ok("[]datetime".to_owned()),
702 elem_type => Err(unsupported_type(&format!("list element {:?}", elem_type))),
703 },
704 DataType::Vector(dimension) => Ok(format!("[{}]f32", dimension)),
705 data_type => Err(unsupported_type(&format!("{:?}", data_type))),
706 }
707}
708
709fn unsupported_type(data_type: &str) -> SinkError {
710 SinkError::Config(anyhow!(
711 "Turbopuffer sink does not support column type {}",
712 data_type
713 ))
714}
715
#[cfg(test)]
mod tests {
    #[cfg(not(madsim))]
    use std::io::{Read, Write};
    #[cfg(not(madsim))]
    use std::net::TcpListener;
    #[cfg(not(madsim))]
    use std::sync::mpsc;
    #[cfg(not(madsim))]
    use std::thread;

    #[cfg(not(madsim))]
    use risingwave_common::array::StreamChunk;
    #[cfg(not(madsim))]
    use risingwave_common::array::stream_chunk::StreamChunkTestExt as _;
    use risingwave_common::array::{ListValue, VectorVal};
    use risingwave_common::catalog::Field;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{ListType, ScalarImpl, Timestamp};
    use serde_json::json;

    use super::*;
    #[cfg(not(madsim))]
    use crate::sink::log_store::DeliveryFutureManager;

    // Column selections become per-attribute flags; vector columns get `ann`.
    #[test]
    fn test_build_schema_flags() {
        let schema = Schema::new(vec![
            Field::with_name(DataType::Varchar, "id"),
            Field::with_name(DataType::Varchar, "body"),
            Field::with_name(DataType::List(ListType::new(DataType::Varchar)), "tags"),
            Field::with_name(DataType::Boolean, "flag"),
            Field::with_name(DataType::Vector(384), "vector"),
        ]);
        let generated = build_turbopuffer_schema(
            &schema,
            &[1, 2, 3, 4],
            &parse_column_selection(Some("body,tags"), &schema, &[1, 2, 3, 4]).unwrap(),
            &parse_column_selection(Some("*"), &schema, &[1, 2, 3, 4]).unwrap(),
        )
        .unwrap();

        assert_eq!(generated["body"]["type"], json!("string"));
        assert_eq!(generated["body"]["filterable"], json!(true));
        assert_eq!(generated["body"]["full_text_search"], json!(true));
        assert_eq!(generated["tags"]["type"], json!("[]string"));
        assert_eq!(generated["tags"]["full_text_search"], json!(true));
        assert_eq!(generated["vector"]["type"], json!("[384]f32"));
        assert_eq!(generated["vector"]["ann"], json!(true));
    }

    // An update in upsert mode becomes one request carrying both the delete and the upsert,
    // sent to the row's namespace with the configured headers.
    #[cfg(not(madsim))]
    #[tokio::test]
    async fn test_write_chunk_posts_batched_payload_and_headers() {
        let (base_url, request_rx, server_thread) = spawn_mock_http_server();
        let schema = Schema::new(vec![
            Field::with_name(DataType::Varchar, "id"),
            Field::with_name(DataType::Varchar, "body"),
            Field::with_name(DataType::Varchar, "workspace_id"),
        ]);
        let payload_indices = vec![1];
        let generated_schema = build_turbopuffer_schema(
            &schema,
            &payload_indices,
            &parse_column_selection(Some("body"), &schema, &payload_indices).unwrap(),
            &parse_column_selection(Some("*"), &schema, &payload_indices).unwrap(),
        )
        .unwrap();
        let config = TurbopufferConfig {
            base_url,
            namespace: None,
            namespace_column: Some("workspace_id".to_owned()),
            api_key: "tpuf_test_key".to_owned(),
            distance_metric: None,
            disable_backpressure: Some(true),
            full_text_search_columns: Some("body".to_owned()),
            filterable_columns: Some("*".to_owned()),
            r#type: "upsert".to_owned(),
        };
        let mut writer = TurbopufferSinkWriter::new(
            config,
            schema,
            false,
            0,
            TurbopufferNamespace::Dynamic { index: 2 },
            payload_indices,
            generated_schema,
        )
        .unwrap();
        let chunk = StreamChunk::from_pretty(
            "T T T
            U- old-id old_body ns_1
            U+ new-id new_body ns_1",
        );
        let mut future_manager = DeliveryFutureManager::new(0);

        writer
            .write_chunk(chunk, future_manager.start_write_chunk(0, 0))
            .await
            .unwrap();
        let request = request_rx.recv().unwrap();
        server_thread.join().unwrap();

        // The request line and headers are lowercased by the mock server.
        assert!(request.starts_with("post /v2/namespaces/ns_1 http/1.1"));
        assert!(request.contains("authorization: bearer tpuf_test_key"));
        assert!(request.contains("content-type: application/json"));

        let body = request.split("\r\n\r\n").nth(1).unwrap();
        let body: Value = serde_json::from_str(body).unwrap();
        assert_eq!(body["disable_backpressure"], json!(true));
        assert_eq!(body["schema"]["body"]["type"], json!("string"));
        assert_eq!(body["schema"]["body"]["filterable"], json!(true));
        assert_eq!(body["schema"]["body"]["full_text_search"], json!(true));
        assert_eq!(body["deletes"], json!(["old-id"]));
        assert_eq!(body["upsert_rows"].as_array().unwrap().len(), 1);
        assert_eq!(body["upsert_rows"][0]["id"], json!("new-id"));
        assert_eq!(body["upsert_rows"][0]["body"], json!("new_body"));
        assert!(body.get("distance_metric").is_none());
    }

    #[test]
    fn test_decimal_and_serial_schema_types() {
        assert_eq!(turbopuffer_type(&DataType::Decimal).unwrap(), "float");
        assert_eq!(turbopuffer_type(&DataType::Serial).unwrap(), "int");
        assert_eq!(
            turbopuffer_type(&DataType::List(ListType::new(DataType::Decimal))).unwrap(),
            "[]float"
        );
        assert_eq!(
            turbopuffer_type(&DataType::List(ListType::new(DataType::Serial))).unwrap(),
            "[]int"
        );
    }

    // End-to-end shape of the generated schema and of a single upsert request body for a
    // realistic wide schema, without going over the network.
    #[test]
    fn test_manual_http_sink_schema_and_payload_shape() {
        let schema = Schema::new(vec![
            Field::with_name(DataType::Varchar, "id"),
            Field::with_name(DataType::Varchar, "workspaceId"),
            Field::with_name(DataType::Varchar, "inboxFeedItemId"),
            Field::with_name(DataType::Varchar, "body"),
            Field::with_name(
                DataType::List(ListType::new(DataType::Varchar)),
                "noteContents",
            ),
            Field::with_name(DataType::Varchar, "communityMemberHandle"),
            Field::with_name(DataType::Varchar, "communityMemberIdentifier"),
            Field::with_name(DataType::Varchar, "inboxFeedItemTitle"),
            Field::with_name(DataType::Boolean, "isInboxFeedItemStarred"),
            Field::with_name(DataType::Boolean, "isInboxFeedItemAnswered"),
            Field::with_name(DataType::Timestamp, "inboxFeedItemPreviewTimestamp"),
            Field::with_name(DataType::Timestamp, "inboxFeedItemPublishTimestamp"),
            Field::with_name(DataType::Int64, "inboxFeedItemAuthorInstagramFollowerCount"),
            Field::with_name(DataType::Int64, "inboxFeedItemAuthorTikTokFollowerCount"),
            Field::with_name(
                DataType::List(ListType::new(DataType::Varchar)),
                "attributes",
            ),
            Field::with_name(DataType::Varchar, "threadId"),
            Field::with_name(DataType::Vector(384), "vector"),
        ]);
        let attribute_indices = (2..schema.len()).collect_vec();
        let full_text_search_columns = parse_column_selection(
            Some(
                "body,noteContents,communityMemberHandle,communityMemberIdentifier,inboxFeedItemTitle",
            ),
            &schema,
            &attribute_indices,
        )
        .unwrap();
        let filterable_columns =
            parse_column_selection(Some("*"), &schema, &attribute_indices).unwrap();
        let generated_schema = build_turbopuffer_schema(
            &schema,
            &attribute_indices,
            &full_text_search_columns,
            &filterable_columns,
        )
        .unwrap();

        assert_eq!(
            generated_schema,
            json!({
                "inboxFeedItemId": {"type": "string", "filterable": true},
                "body": {"type": "string", "filterable": true, "full_text_search": true},
                "noteContents": {"type": "[]string", "filterable": true, "full_text_search": true},
                "communityMemberHandle": {"type": "string", "filterable": true, "full_text_search": true},
                "communityMemberIdentifier": {"type": "string", "filterable": true, "full_text_search": true},
                "inboxFeedItemTitle": {"type": "string", "filterable": true, "full_text_search": true},
                "isInboxFeedItemStarred": {"type": "bool", "filterable": true},
                "isInboxFeedItemAnswered": {"type": "bool", "filterable": true},
                "inboxFeedItemPreviewTimestamp": {"type": "datetime", "filterable": true},
                "inboxFeedItemPublishTimestamp": {"type": "datetime", "filterable": true},
                "inboxFeedItemAuthorInstagramFollowerCount": {"type": "int", "filterable": true},
                "inboxFeedItemAuthorTikTokFollowerCount": {"type": "int", "filterable": true},
                "attributes": {"type": "[]string", "filterable": true},
                "threadId": {"type": "string", "filterable": true},
                "vector": {"type": "[384]f32", "filterable": true, "ann": true}
            })
        );

        let config = TurbopufferConfig {
            base_url: "http://127.0.0.1:0".to_owned(),
            namespace: None,
            namespace_column: Some("workspaceId".to_owned()),
            api_key: "tpuf_test_key".to_owned(),
            distance_metric: Some("cosine_distance".to_owned()),
            disable_backpressure: Some(true),
            full_text_search_columns: Some("body,noteContents,communityMemberHandle,communityMemberIdentifier,inboxFeedItemTitle".to_owned()),
            filterable_columns: Some("*".to_owned()),
            r#type: "upsert".to_owned(),
        };
        let writer = TurbopufferSinkWriter::new(
            config,
            schema,
            false,
            0,
            TurbopufferNamespace::Dynamic { index: 1 },
            attribute_indices,
            generated_schema.clone(),
        )
        .unwrap();
        let vector =
            VectorVal::from_text(&format!("[{}]", vec!["0.25"; 384].join(",")), 384).unwrap();
        let row = OwnedRow::new(vec![
            Some(ScalarImpl::Utf8("doc-1".into())),
            Some(ScalarImpl::Utf8("workspace-1".into())),
            Some(ScalarImpl::Utf8("item-1".into())),
            Some(ScalarImpl::Utf8("body text".into())),
            Some(ScalarImpl::List(ListValue::from_iter(["note a", "note b"]))),
            Some(ScalarImpl::Utf8("@member".into())),
            Some(ScalarImpl::Utf8("member-1".into())),
            Some(ScalarImpl::Utf8("title".into())),
            Some(ScalarImpl::Bool(true)),
            Some(ScalarImpl::Bool(false)),
            Some(ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(
                1_781_582_706,
                0,
            ))),
            Some(ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(
                1_781_598_707,
                0,
            ))),
            Some(ScalarImpl::Int64(12345)),
            Some(ScalarImpl::Int64(67890)),
            Some(ScalarImpl::List(ListValue::from_iter(["important", "vip"]))),
            Some(ScalarImpl::Utf8("thread-1".into())),
            Some(ScalarImpl::Vector(vector)),
        ]);
        let id = writer.id_for_row(&row).unwrap();
        let upsert_row = writer.upsert_row(&row, id).unwrap();
        let body = writer.request_body(vec![upsert_row], Vec::new());

        assert_eq!(body["distance_metric"], json!("cosine_distance"));
        assert_eq!(body["disable_backpressure"], json!(true));
        assert_eq!(body["schema"], generated_schema);
        assert_eq!(body["upsert_rows"][0]["id"], json!("doc-1"));
        assert_eq!(body["upsert_rows"][0]["body"], json!("body text"));
        assert_eq!(
            body["upsert_rows"][0]["noteContents"],
            json!(["note a", "note b"])
        );
        assert_eq!(
            body["upsert_rows"][0]["isInboxFeedItemStarred"],
            json!(true)
        );
        assert_eq!(
            body["upsert_rows"][0]["isInboxFeedItemAnswered"],
            json!(false)
        );
        assert_eq!(
            body["upsert_rows"][0]["inboxFeedItemAuthorInstagramFollowerCount"],
            json!(12345)
        );
        assert_eq!(
            body["upsert_rows"][0]["inboxFeedItemPreviewTimestamp"],
            json!("2026-06-16 04:05:06.000000")
        );
        assert_eq!(
            body["upsert_rows"][0]["vector"].as_array().unwrap().len(),
            384
        );
        assert_eq!(body["upsert_rows"][0]["vector"][0], json!(0.25));
    }

    /// Starts a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// It accepts a single connection and reads one request (headers plus a
    /// `Content-Length` body). The request is forwarded over the channel as
    /// `"<head>\r\n\r\n<body>"`, then the server replies `200 OK` with an empty body. The
    /// request line and headers are lowercased so assertions can ignore header-name casing.
    /// The body is forwarded verbatim so payload values are compared exactly.
    #[cfg(not(madsim))]
    fn spawn_mock_http_server() -> (String, mpsc::Receiver<String>, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let (request_tx, request_rx) = mpsc::channel();
        let server_thread = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = Vec::new();
            let header_end = loop {
                let mut tmp = [0; 1024];
                let read = stream.read(&mut tmp).unwrap();
                assert_ne!(read, 0);
                buf.extend_from_slice(&tmp[..read]);
                if let Some(header_end) = find_header_end(&buf) {
                    break header_end;
                }
            };
            let headers = String::from_utf8_lossy(&buf[..header_end]);
            let content_length = headers
                .lines()
                .find_map(|line| {
                    let (name, value) = line.split_once(':')?;
                    name.eq_ignore_ascii_case("content-length")
                        .then(|| value.trim().parse::<usize>().unwrap())
                })
                .unwrap();
            let body_start = header_end + 4;
            let body_end = body_start + content_length;
            while buf.len() < body_end {
                let mut tmp = [0; 1024];
                let read = stream.read(&mut tmp).unwrap();
                assert_ne!(read, 0);
                buf.extend_from_slice(&tmp[..read]);
            }
            let head = String::from_utf8(buf[..header_end].to_vec())
                .expect("HTTP request head should be utf8");
            let body = String::from_utf8(buf[body_start..body_end].to_vec())
                .expect("HTTP request body should be utf8");
            request_tx
                .send(format!("{}\r\n\r\n{}", head.to_lowercase(), body))
                .unwrap();
            stream
                .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
                .unwrap();
        });
        (format!("http://{}", addr), request_rx, server_thread)
    }

    /// Returns the byte offset of the `\r\n\r\n` that ends the HTTP header section, if present.
    #[cfg(not(madsim))]
    fn find_header_end(buf: &[u8]) -> Option<usize> {
        buf.windows(4).position(|window| window == b"\r\n\r\n")
    }
}