Skip to main content

risingwave_connector/sink/
turbopuffer.rs

// Copyright 2026 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
use std::collections::{BTreeMap, HashMap, HashSet};

use anyhow::{Context, anyhow};
use itertools::Itertools;
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::Row;
use risingwave_common::types::{DataType, ScalarRefImpl};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use with_options::WithOptions;

use crate::enforce_secret::EnforceSecret;
use crate::sink::encoder::{JsonEncoder, RowEncoder};
use crate::sink::log_store::DeliveryFutureManagerAddFuture;
use crate::sink::writer::{
    AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, AsyncTruncateSinkWriterExt,
};
use crate::sink::{Result, Sink, SinkError, SinkParam, SinkWriterParam};
37
/// Connector name under which the turbopuffer sink is registered (see `Sink::SINK_NAME`).
pub const TURBOPUFFER_SINK: &str = "turbopuffer";
39
40#[serde_as]
41#[derive(Clone, Debug, Deserialize, WithOptions)]
42pub struct TurbopufferConfig {
43    pub base_url: String,
44    pub namespace: Option<String>,
45    pub namespace_column: Option<String>,
46    pub api_key: String,
47    pub distance_metric: Option<String>,
48    #[serde_as(as = "Option<DisplayFromStr>")]
49    pub disable_backpressure: Option<bool>,
50    pub full_text_search_columns: Option<String>,
51    pub filterable_columns: Option<String>,
52    pub r#type: String, // accept "append-only" or "upsert"
53}
54
55impl EnforceSecret for TurbopufferConfig {
56    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
57        "api_key",
58    };
59}
60
61impl TurbopufferConfig {
62    fn from_btreemap(values: BTreeMap<String, String>) -> Result<Self> {
63        serde_json::from_value::<TurbopufferConfig>(
64            serde_json::to_value(values).expect("serialize sink properties"),
65        )
66        .map_err(|e| SinkError::Config(anyhow!(e)))
67    }
68}
69
/// How the target namespace of a write is determined.
#[derive(Clone, Debug)]
enum TurbopufferNamespace {
    /// One fixed namespace for every row.
    Static(String),
    /// The namespace is read per row from the varchar column at `index` of the sink schema.
    Dynamic { index: usize },
}
75
76#[derive(Clone, Debug)]
77pub struct TurbopufferSink {
78    config: TurbopufferConfig,
79    schema: Schema,
80    is_append_only: bool,
81    pk_index: usize,
82    namespace: TurbopufferNamespace,
83    attribute_indices: Vec<usize>,
84    generated_schema: Value,
85}
86
87impl EnforceSecret for TurbopufferSink {
88    fn enforce_secret<'a>(
89        prop_iter: impl Iterator<Item = &'a str>,
90    ) -> crate::error::ConnectorResult<()> {
91        for prop in prop_iter {
92            TurbopufferConfig::enforce_one(prop)?;
93        }
94        Ok(())
95    }
96}
97
98impl TryFrom<SinkParam> for TurbopufferSink {
99    type Error = SinkError;
100
101    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
102        let schema = param.schema();
103        let is_append_only = param.sink_type.is_append_only();
104        let pk_indices = param.downstream_pk_or_empty();
105        let [pk_index] = pk_indices.as_slice() else {
106            return Err(SinkError::Config(anyhow!(
107                "Turbopuffer sink requires exactly one primary_key column"
108            )));
109        };
110        let pk_index = *pk_index;
111        match schema[pk_index].data_type() {
112            DataType::Int16
113            | DataType::Int32
114            | DataType::Int64
115            | DataType::Serial
116            | DataType::Varchar => {}
117            data_type => {
118                return Err(SinkError::Config(anyhow!(
119                    "Turbopuffer document id column must be an integer or varchar, got {:?}",
120                    data_type
121                )));
122            }
123        };
124        let config = TurbopufferConfig::from_btreemap(param.properties)?;
125
126        let namespace = match (&config.namespace, &config.namespace_column) {
127            (Some(namespace), None) => {
128                validate_namespace(namespace)?;
129                TurbopufferNamespace::Static(namespace.clone())
130            }
131            (None, Some(namespace_column)) => {
132                let index = schema
133                    .fields()
134                    .iter()
135                    .position(|field| field.name == *namespace_column)
136                    .ok_or_else(|| {
137                        SinkError::Config(anyhow!(
138                            "Turbopuffer namespace_column '{}' not found in sink schema",
139                            namespace_column
140                        ))
141                    })?;
142                if schema[index].data_type != DataType::Varchar {
143                    return Err(SinkError::Config(anyhow!(
144                        "Turbopuffer namespace_column must be varchar, got {:?}",
145                        schema[index].data_type
146                    )));
147                }
148                TurbopufferNamespace::Dynamic { index }
149            }
150            (Some(_), Some(_)) => {
151                return Err(SinkError::Config(anyhow!(
152                    "Turbopuffer sink requires only one of namespace or namespace_column"
153                )));
154            }
155            (None, None) => {
156                return Err(SinkError::Config(anyhow!(
157                    "Turbopuffer sink requires either namespace or namespace_column"
158                )));
159            }
160        };
161
162        // Turbopuffer treats `id` as the document ID in write requests; it is not a schema
163        // attribute. Dynamic namespace is also metadata for routing, not a document attribute.
164        let excluded_indices = match &namespace {
165            TurbopufferNamespace::Static(_) => HashSet::from([pk_index]),
166            TurbopufferNamespace::Dynamic { index } => HashSet::from([pk_index, *index]),
167        };
168        let attribute_indices = (0..schema.len())
169            .filter(|idx| !excluded_indices.contains(idx))
170            .collect_vec();
171        for index in &attribute_indices {
172            if schema[*index].name == "id" {
173                return Err(SinkError::Config(anyhow!(
174                    "Turbopuffer attribute column must not be named id"
175                )));
176            }
177        }
178        let full_text_search_columns = parse_column_selection(
179            config.full_text_search_columns.as_deref(),
180            &schema,
181            &attribute_indices,
182        )?;
183        let filterable_columns = parse_column_selection(
184            config.filterable_columns.as_deref(),
185            &schema,
186            &attribute_indices,
187        )?;
188        let has_vector = attribute_indices
189            .iter()
190            .any(|idx| matches!(schema[*idx].data_type, DataType::Vector(_)));
191        if has_vector && config.distance_metric.is_none() {
192            return Err(SinkError::Config(anyhow!(
193                "Turbopuffer sink requires distance_metric when sink schema contains vector columns"
194            )));
195        }
196        // This validates every document attribute type before the writer is created:
197        // `build_turbopuffer_schema` calls `turbopuffer_type` for each attribute and
198        // returns a config error for unsupported types.
199        let generated_schema = build_turbopuffer_schema(
200            &schema,
201            &attribute_indices,
202            &full_text_search_columns,
203            &filterable_columns,
204        )?;
205
206        Ok(Self {
207            config,
208            schema,
209            is_append_only,
210            pk_index,
211            namespace,
212            attribute_indices,
213            generated_schema,
214        })
215    }
216}
217
218impl Sink for TurbopufferSink {
219    type LogSinker = AsyncTruncateLogSinkerOf<TurbopufferSinkWriter>;
220
221    const SINK_NAME: &'static str = TURBOPUFFER_SINK;
222
223    async fn validate(&self) -> Result<()> {
224        Ok(())
225    }
226
227    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
228        Ok(TurbopufferSinkWriter::new(
229            self.config.clone(),
230            self.schema.clone(),
231            self.is_append_only,
232            self.pk_index,
233            self.namespace.clone(),
234            self.attribute_indices.clone(),
235            self.generated_schema.clone(),
236        )?
237        .into_log_sinker(usize::MAX))
238    }
239}
240
241pub struct TurbopufferSinkWriter {
242    client: reqwest::Client,
243    base_url: String,
244    distance_metric: Option<String>,
245    disable_backpressure: Option<bool>,
246    schema: Value,
247    is_append_only: bool,
248    pk_index: usize,
249    namespace: TurbopufferNamespace,
250    row_encoder: JsonEncoder,
251}
252
253impl TurbopufferSinkWriter {
254    fn new(
255        config: TurbopufferConfig,
256        schema: Schema,
257        is_append_only: bool,
258        pk_index: usize,
259        namespace: TurbopufferNamespace,
260        attribute_indices: Vec<usize>,
261        generated_schema: Value,
262    ) -> Result<Self> {
263        let mut header_map = HeaderMap::new();
264        header_map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
265        let authorization = format!("Bearer {}", config.api_key);
266        header_map.insert(
267            AUTHORIZATION,
268            authorization
269                .parse()
270                .context("invalid turbopuffer api_key")
271                .map_err(SinkError::Config)?,
272        );
273        let client = reqwest::Client::builder()
274            .default_headers(header_map)
275            .build()
276            .context("failed to build turbopuffer HTTP client")
277            .map_err(SinkError::Http)?;
278        let base_url = config
279            .base_url
280            .parse::<reqwest::Url>()
281            .context("invalid turbopuffer base_url")
282            .map_err(SinkError::Config)?
283            .to_string()
284            .trim_end_matches('/')
285            .to_owned();
286        let row_encoder = JsonEncoder::new_with_turbopuffer(schema, Some(attribute_indices));
287        Ok(Self {
288            client,
289            base_url,
290            distance_metric: config.distance_metric,
291            disable_backpressure: config.disable_backpressure,
292            schema: generated_schema,
293            is_append_only,
294            pk_index,
295            namespace,
296            row_encoder,
297        })
298    }
299
300    fn url_for_row(&self, row: &impl Row) -> Result<String> {
301        match &self.namespace {
302            TurbopufferNamespace::Static(namespace) => Ok(format!(
303                "{}/v2/namespaces/{}",
304                self.base_url,
305                namespace.as_str()
306            )),
307            TurbopufferNamespace::Dynamic { index } => {
308                let namespace = match row.datum_at(*index) {
309                    Some(ScalarRefImpl::Utf8(namespace)) => namespace,
310                    None => {
311                        return Err(SinkError::Http(anyhow!(
312                            "Turbopuffer namespace_column cannot be null"
313                        )));
314                    }
315                    Some(_) => {
316                        return Err(SinkError::Http(anyhow!(
317                            "unexpected namespace_column type, expected varchar"
318                        )));
319                    }
320                };
321                validate_namespace(namespace)?;
322                Ok(format!("{}/v2/namespaces/{}", self.base_url, namespace))
323            }
324        }
325    }
326
327    // Turbopuffer document IDs are unsigned 64-bit integers, UUIDs, or strings up to 64 bytes.
328    // RisingWave UUID IDs can be represented with varchar.
329    fn id_for_row(&self, row: &impl Row) -> Result<DocumentId> {
330        let datum = row.datum_at(self.pk_index).ok_or_else(|| {
331            SinkError::Http(anyhow!("Turbopuffer document id column cannot be null"))
332        })?;
333        match datum {
334            ScalarRefImpl::Int16(value) => Ok(document_id_from_i64(value as i64)),
335            ScalarRefImpl::Int32(value) => Ok(document_id_from_i64(value as i64)),
336            ScalarRefImpl::Int64(value) => Ok(document_id_from_i64(value)),
337            ScalarRefImpl::Serial(value) => Ok(document_id_from_i64(value.into_inner())),
338            ScalarRefImpl::Utf8(value) => {
339                if value.len() > 64 {
340                    return Err(SinkError::Http(anyhow!(
341                        "Turbopuffer string document id exceeds 64 bytes"
342                    )));
343                }
344                Ok(DocumentId::String(value.to_owned()))
345            }
346            _ => Err(SinkError::Http(anyhow!(
347                "Turbopuffer document id column must be an integer or varchar"
348            ))),
349        }
350    }
351
352    fn upsert_row(&self, row: &impl Row, id: DocumentId) -> Result<Map<String, Value>> {
353        let mut value = self.row_encoder.encode(row)?;
354        value.insert(
355            "id".to_owned(),
356            serde_json::to_value(id).expect("serialize document id"),
357        );
358        Ok(value)
359    }
360
361    fn request_body(
362        &self,
363        upsert_rows: Vec<Map<String, Value>>,
364        deletes: Vec<DocumentId>,
365    ) -> Value {
366        let mut body = Map::new();
367        if let Some(distance_metric) = &self.distance_metric {
368            body.insert(
369                "distance_metric".to_owned(),
370                Value::String(distance_metric.clone()),
371            );
372        }
373        if !upsert_rows.is_empty() {
374            if let Some(disable_backpressure) = self.disable_backpressure {
375                body.insert(
376                    "disable_backpressure".to_owned(),
377                    Value::Bool(disable_backpressure),
378                );
379            }
380            body.insert("schema".to_owned(), self.schema.clone());
381            body.insert(
382                "upsert_rows".to_owned(),
383                Value::Array(upsert_rows.into_iter().map(Value::Object).collect()),
384            );
385        }
386        if !deletes.is_empty() {
387            body.insert(
388                "deletes".to_owned(),
389                Value::Array(
390                    deletes
391                        .into_iter()
392                        .map(|id| serde_json::to_value(id).expect("serialize document id"))
393                        .collect(),
394                ),
395            );
396        }
397        Value::Object(body)
398    }
399}
400
401impl AsyncTruncateSinkWriter for TurbopufferSinkWriter {
402    async fn write_chunk<'a>(
403        &'a mut self,
404        chunk: StreamChunk,
405        _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
406    ) -> Result<()> {
407        if self.is_append_only {
408            let mut batches = BTreeMap::<String, Vec<Map<String, Value>>>::new();
409            for (op, row) in chunk.rows() {
410                match op {
411                    Op::Insert | Op::UpdateInsert => {
412                        let id = match self.id_for_row(&row) {
413                            Ok(id) => id,
414                            Err(err) => {
415                                tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid document id");
416                                continue;
417                            }
418                        };
419                        let upsert_row = match self.upsert_row(&row, id) {
420                            Ok(row) => row,
421                            Err(err) => {
422                                tracing::warn!(error = %err.as_report(), "skip turbopuffer row failed to encode upsert payload");
423                                continue;
424                            }
425                        };
426                        let url = match self.url_for_row(&row) {
427                            Ok(url) => url,
428                            Err(err) => {
429                                tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid namespace");
430                                continue;
431                            }
432                        };
433                        batches.entry(url).or_default().push(upsert_row);
434                    }
435                    Op::Delete | Op::UpdateDelete => {
436                        return Err(SinkError::Http(anyhow!(
437                            "`Delete` or `UpdateDelete` operation is not supported in append-only turbopuffer sink"
438                        )));
439                    }
440                }
441            }
442
443            for (url, upsert_rows) in batches {
444                if upsert_rows.is_empty() {
445                    continue;
446                }
447                let resp = self
448                    .client
449                    .post(url)
450                    .json(&self.request_body(upsert_rows, Vec::new()))
451                    .send()
452                    .await
453                    .context("turbopuffer write request failed")
454                    .map_err(SinkError::Http)?;
455
456                if !resp.status().is_success() {
457                    let status = resp.status();
458                    let body = resp.text().await.unwrap_or_default();
459                    return Err(SinkError::Http(anyhow!(
460                        "Turbopuffer sink received non-success response: {} {}",
461                        status,
462                        body
463                    )));
464                }
465            }
466        } else {
467            let mut batches = BTreeMap::<String, HashMap<DocumentId, CompactedOp>>::new();
468            for (op, row) in chunk.rows() {
469                let id = match self.id_for_row(&row) {
470                    Ok(id) => id,
471                    Err(err) => {
472                        tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid document id");
473                        continue;
474                    }
475                };
476                let url = match self.url_for_row(&row) {
477                    Ok(url) => url,
478                    Err(err) => {
479                        tracing::warn!(error = %err.as_report(), "skip turbopuffer row with invalid namespace");
480                        continue;
481                    }
482                };
483                match op {
484                    Op::Insert | Op::UpdateInsert => {
485                        let upsert_row = match self.upsert_row(&row, id.clone()) {
486                            Ok(row) => row,
487                            Err(err) => {
488                                tracing::warn!(error = %err.as_report(), "skip turbopuffer row failed to encode upsert payload");
489                                continue;
490                            }
491                        };
492                        batches
493                            .entry(url)
494                            .or_default()
495                            .insert(id, CompactedOp::Upsert(upsert_row));
496                    }
497                    Op::Delete | Op::UpdateDelete => {
498                        batches
499                            .entry(url)
500                            .or_default()
501                            .insert(id, CompactedOp::Delete);
502                    }
503                }
504            }
505
506            for (url, batch) in batches {
507                let mut upsert_rows = Vec::new();
508                let mut deletes = Vec::new();
509                for (id, op) in batch {
510                    match op {
511                        CompactedOp::Upsert(row) => upsert_rows.push(row),
512                        CompactedOp::Delete => deletes.push(id),
513                    }
514                }
515                if upsert_rows.is_empty() && deletes.is_empty() {
516                    continue;
517                }
518                let resp = self
519                    .client
520                    .post(url)
521                    .json(&self.request_body(upsert_rows, deletes))
522                    .send()
523                    .await
524                    .context("turbopuffer write request failed")
525                    .map_err(SinkError::Http)?;
526
527                if !resp.status().is_success() {
528                    let status = resp.status();
529                    let body = resp.text().await.unwrap_or_default();
530                    return Err(SinkError::Http(anyhow!(
531                        "Turbopuffer sink received non-success response: {} {}",
532                        status,
533                        body
534                    )));
535                }
536            }
537        }
538
539        Ok(())
540    }
541}
542
543#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize)]
544#[serde(untagged)]
545enum DocumentId {
546    U64(u64),
547    String(String),
548}
549
550#[derive(Debug)]
551enum CompactedOp {
552    Upsert(Map<String, Value>),
553    Delete,
554}
555
556fn document_id_from_i64(value: i64) -> DocumentId {
557    if value < 0 {
558        tracing::warn!(
559            value,
560            "cast negative turbopuffer integer document id to unsigned integer"
561        );
562    }
563    DocumentId::U64(value as u64)
564}
565
566fn validate_namespace(namespace: &str) -> Result<()> {
567    if namespace.is_empty() || namespace.len() > 128 {
568        return Err(SinkError::Config(anyhow!(
569            "Turbopuffer namespace must be 1 to 128 bytes"
570        )));
571    }
572    if !namespace
573        .bytes()
574        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.'))
575    {
576        return Err(SinkError::Config(anyhow!(
577            "Turbopuffer namespace must match [A-Za-z0-9-_.]{{1,128}}"
578        )));
579    }
580    Ok(())
581}
582
583fn parse_column_selection(
584    value: Option<&str>,
585    schema: &Schema,
586    attribute_indices: &[usize],
587) -> Result<HashSet<String>> {
588    let attribute_names = attribute_indices
589        .iter()
590        .map(|index| schema[*index].name.as_str())
591        .collect::<HashSet<_>>();
592    let columns: HashSet<String> = match value {
593        Some(value) if value.trim() == "*" => {
594            return Ok(attribute_names
595                .into_iter()
596                .map(str::to_owned)
597                .collect::<HashSet<_>>());
598        }
599        Some(value) => value
600            .split(',')
601            .map(str::trim)
602            .filter(|name| !name.is_empty())
603            .map(str::to_owned)
604            .collect(),
605        None => return Ok(HashSet::new()),
606    };
607    for column in &columns {
608        if !attribute_names.contains(column.as_str()) {
609            return Err(SinkError::Config(anyhow!(
610                "Turbopuffer schema option references unknown attribute column '{}'",
611                column
612            )));
613        }
614    }
615    Ok(columns)
616}
617
618fn build_turbopuffer_schema(
619    schema: &Schema,
620    attribute_indices: &[usize],
621    full_text_search_columns: &HashSet<String>,
622    filterable_columns: &HashSet<String>,
623) -> Result<Value> {
624    let mut result = Map::new();
625    for index in attribute_indices {
626        let field = &schema[*index];
627        let mut config = Map::new();
628        let data_type = field.data_type();
629        let turbopuffer_type = turbopuffer_type(&data_type)?;
630        let is_vector = matches!(data_type, DataType::Vector(_));
631        let is_full_text_search = full_text_search_columns.contains(&field.name);
632        if is_full_text_search && !supports_full_text_search(&data_type) {
633            return Err(SinkError::Config(anyhow!(
634                "Turbopuffer full_text_search column '{}' must be string or []string",
635                field.name
636            )));
637        }
638        config.insert("type".to_owned(), Value::String(turbopuffer_type));
639        if filterable_columns.contains(&field.name) {
640            config.insert("filterable".to_owned(), Value::Bool(true));
641        }
642        if is_full_text_search {
643            config.insert("full_text_search".to_owned(), Value::Bool(true));
644        }
645        if is_vector {
646            config.insert("ann".to_owned(), Value::Bool(true));
647        }
648        result.insert(field.name.clone(), Value::Object(config));
649    }
650    Ok(Value::Object(result))
651}
652
653fn supports_full_text_search(data_type: &DataType) -> bool {
654    match data_type {
655        DataType::Varchar => true,
656        DataType::List(list_type) => matches!(list_type.elem(), DataType::Varchar),
657        _ => false,
658    }
659}
660
661// Mapping from RisingWave attribute types to generated turbopuffer schema types and
662// the JSON value shapes emitted by `JsonEncoder`:
663//
664// | RisingWave type                  | turbopuffer type | JSON payload                  |
665// |----------------------------------|------------------|-------------------------------|
666// | boolean                          | bool             | boolean                       |
667// | int16, int32, int64              | int              | number                        |
668// | float32, float64                 | float            | number                        |
669// | varchar                          | string           | string                        |
670// | date                             | datetime         | string: YYYY-MM-DD            |
671// | timestamp                        | datetime         | string: YYYY-MM-DD HH:MM:SS   |
672// | boolean[]                        | []bool           | array of booleans             |
673// | int16[], int32[], int64[]        | []int            | array of numbers              |
674// | float32[], float64[]             | []float          | array of numbers              |
675// | varchar[]                        | []string         | array of strings              |
676// | date[], timestamp[]              | []datetime       | array of datetime strings     |
677// | vector(N)                        | [N]f32           | array of numbers              |
678// | serial                           | int              | number                        |
679// | decimal                          | float            | number, converted through f64 |
680// | serial[]                         | []int            | array of numbers              |
681// | decimal[]                        | []float          | array of f64-converted numbers|
682//
683// The primary key column is encoded separately as the turbopuffer document id, so
684// it does not participate in this schema mapping.
685fn turbopuffer_type(data_type: &DataType) -> Result<String> {
686    match data_type {
687        DataType::Boolean => Ok("bool".to_owned()),
688        DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => {
689            Ok("int".to_owned())
690        }
691        DataType::Float32 | DataType::Float64 | DataType::Decimal => Ok("float".to_owned()),
692        DataType::Varchar => Ok("string".to_owned()),
693        DataType::Date | DataType::Timestamp => Ok("datetime".to_owned()),
694        DataType::List(list_type) => match list_type.elem() {
695            DataType::Boolean => Ok("[]bool".to_owned()),
696            DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => {
697                Ok("[]int".to_owned())
698            }
699            DataType::Float32 | DataType::Float64 | DataType::Decimal => Ok("[]float".to_owned()),
700            DataType::Varchar => Ok("[]string".to_owned()),
701            DataType::Date | DataType::Timestamp => Ok("[]datetime".to_owned()),
702            elem_type => Err(unsupported_type(&format!("list element {:?}", elem_type))),
703        },
704        DataType::Vector(dimension) => Ok(format!("[{}]f32", dimension)),
705        data_type => Err(unsupported_type(&format!("{:?}", data_type))),
706    }
707}
708
709fn unsupported_type(data_type: &str) -> SinkError {
710    SinkError::Config(anyhow!(
711        "Turbopuffer sink does not support column type {}",
712        data_type
713    ))
714}
715
716#[cfg(test)]
717mod tests {
718    #[cfg(not(madsim))]
719    use std::io::{Read, Write};
720    #[cfg(not(madsim))]
721    use std::net::TcpListener;
722    #[cfg(not(madsim))]
723    use std::sync::mpsc;
724    #[cfg(not(madsim))]
725    use std::thread;
726
727    #[cfg(not(madsim))]
728    use risingwave_common::array::StreamChunk;
729    #[cfg(not(madsim))]
730    use risingwave_common::array::stream_chunk::StreamChunkTestExt as _;
731    use risingwave_common::array::{ListValue, VectorVal};
732    use risingwave_common::catalog::Field;
733    use risingwave_common::row::OwnedRow;
734    use risingwave_common::types::{ListType, ScalarImpl, Timestamp};
735    use serde_json::json;
736
737    use super::*;
738    #[cfg(not(madsim))]
739    use crate::sink::log_store::DeliveryFutureManager;
740
741    #[test]
742    fn test_build_schema_flags() {
743        let schema = Schema::new(vec![
744            Field::with_name(DataType::Varchar, "id"),
745            Field::with_name(DataType::Varchar, "body"),
746            Field::with_name(DataType::List(ListType::new(DataType::Varchar)), "tags"),
747            Field::with_name(DataType::Boolean, "flag"),
748            Field::with_name(DataType::Vector(384), "vector"),
749        ]);
750        let generated = build_turbopuffer_schema(
751            &schema,
752            &[1, 2, 3, 4],
753            &parse_column_selection(Some("body,tags"), &schema, &[1, 2, 3, 4]).unwrap(),
754            &parse_column_selection(Some("*"), &schema, &[1, 2, 3, 4]).unwrap(),
755        )
756        .unwrap();
757
758        assert_eq!(generated["body"]["type"], json!("string"));
759        assert_eq!(generated["body"]["filterable"], json!(true));
760        assert_eq!(generated["body"]["full_text_search"], json!(true));
761        assert_eq!(generated["tags"]["type"], json!("[]string"));
762        assert_eq!(generated["tags"]["full_text_search"], json!(true));
763        assert_eq!(generated["vector"]["type"], json!("[384]f32"));
764        assert_eq!(generated["vector"]["ann"], json!(true));
765    }
766
767    #[cfg(not(madsim))]
768    #[tokio::test]
769    async fn test_write_chunk_posts_batched_payload_and_headers() {
770        let (base_url, request_rx, server_thread) = spawn_mock_http_server();
771        let schema = Schema::new(vec![
772            Field::with_name(DataType::Varchar, "id"),
773            Field::with_name(DataType::Varchar, "body"),
774            Field::with_name(DataType::Varchar, "workspace_id"),
775        ]);
776        let payload_indices = vec![1];
777        let generated_schema = build_turbopuffer_schema(
778            &schema,
779            &payload_indices,
780            &parse_column_selection(Some("body"), &schema, &payload_indices).unwrap(),
781            &parse_column_selection(Some("*"), &schema, &payload_indices).unwrap(),
782        )
783        .unwrap();
784        let config = TurbopufferConfig {
785            base_url,
786            namespace: None,
787            namespace_column: Some("workspace_id".to_owned()),
788            api_key: "tpuf_test_key".to_owned(),
789            distance_metric: None,
790            disable_backpressure: Some(true),
791            full_text_search_columns: Some("body".to_owned()),
792            filterable_columns: Some("*".to_owned()),
793            r#type: "upsert".to_owned(),
794        };
795        let mut writer = TurbopufferSinkWriter::new(
796            config,
797            schema,
798            false,
799            0,
800            TurbopufferNamespace::Dynamic { index: 2 },
801            payload_indices,
802            generated_schema,
803        )
804        .unwrap();
805        let chunk = StreamChunk::from_pretty(
806            "T  T        T
807            U- old-id   old_body ns_1
808            U+ new-id   new_body ns_1",
809        );
810        let mut future_manager = DeliveryFutureManager::new(0);
811
812        writer
813            .write_chunk(chunk, future_manager.start_write_chunk(0, 0))
814            .await
815            .unwrap();
816        let request = request_rx.recv().unwrap();
817        server_thread.join().unwrap();
818
819        assert!(request.starts_with("post /v2/namespaces/ns_1 http/1.1"));
820        assert!(request.contains("authorization: bearer tpuf_test_key"));
821        assert!(request.contains("content-type: application/json"));
822
823        let body = request.split("\r\n\r\n").nth(1).unwrap();
824        let body: Value = serde_json::from_str(body).unwrap();
825        assert_eq!(body["disable_backpressure"], json!(true));
826        assert_eq!(body["schema"]["body"]["type"], json!("string"));
827        assert_eq!(body["schema"]["body"]["filterable"], json!(true));
828        assert_eq!(body["schema"]["body"]["full_text_search"], json!(true));
829        assert_eq!(body["deletes"], json!(["old-id"]));
830        assert_eq!(body["upsert_rows"].as_array().unwrap().len(), 1);
831        assert_eq!(body["upsert_rows"][0]["id"], json!("new-id"));
832        assert_eq!(body["upsert_rows"][0]["body"], json!("new_body"));
833        assert!(body.get("distance_metric").is_none());
834    }
835
836    #[test]
837    fn test_decimal_and_serial_schema_types() {
838        assert_eq!(turbopuffer_type(&DataType::Decimal).unwrap(), "float");
839        assert_eq!(turbopuffer_type(&DataType::Serial).unwrap(), "int");
840        assert_eq!(
841            turbopuffer_type(&DataType::List(ListType::new(DataType::Decimal))).unwrap(),
842            "[]float"
843        );
844        assert_eq!(
845            turbopuffer_type(&DataType::List(ListType::new(DataType::Serial))).unwrap(),
846            "[]int"
847        );
848    }
849
850    #[test]
851    fn test_manual_http_sink_schema_and_payload_shape() {
852        let schema = Schema::new(vec![
853            Field::with_name(DataType::Varchar, "id"),
854            Field::with_name(DataType::Varchar, "workspaceId"),
855            Field::with_name(DataType::Varchar, "inboxFeedItemId"),
856            Field::with_name(DataType::Varchar, "body"),
857            Field::with_name(
858                DataType::List(ListType::new(DataType::Varchar)),
859                "noteContents",
860            ),
861            Field::with_name(DataType::Varchar, "communityMemberHandle"),
862            Field::with_name(DataType::Varchar, "communityMemberIdentifier"),
863            Field::with_name(DataType::Varchar, "inboxFeedItemTitle"),
864            Field::with_name(DataType::Boolean, "isInboxFeedItemStarred"),
865            Field::with_name(DataType::Boolean, "isInboxFeedItemAnswered"),
866            Field::with_name(DataType::Timestamp, "inboxFeedItemPreviewTimestamp"),
867            Field::with_name(DataType::Timestamp, "inboxFeedItemPublishTimestamp"),
868            Field::with_name(DataType::Int64, "inboxFeedItemAuthorInstagramFollowerCount"),
869            Field::with_name(DataType::Int64, "inboxFeedItemAuthorTikTokFollowerCount"),
870            Field::with_name(
871                DataType::List(ListType::new(DataType::Varchar)),
872                "attributes",
873            ),
874            Field::with_name(DataType::Varchar, "threadId"),
875            Field::with_name(DataType::Vector(384), "vector"),
876        ]);
877        let attribute_indices = (2..schema.len()).collect_vec();
878        let full_text_search_columns = parse_column_selection(
879            Some(
880                "body,noteContents,communityMemberHandle,communityMemberIdentifier,inboxFeedItemTitle",
881            ),
882            &schema,
883            &attribute_indices,
884        )
885        .unwrap();
886        let filterable_columns =
887            parse_column_selection(Some("*"), &schema, &attribute_indices).unwrap();
888        let generated_schema = build_turbopuffer_schema(
889            &schema,
890            &attribute_indices,
891            &full_text_search_columns,
892            &filterable_columns,
893        )
894        .unwrap();
895
896        assert_eq!(
897            generated_schema,
898            json!({
899                "inboxFeedItemId": {"type": "string", "filterable": true},
900                "body": {"type": "string", "filterable": true, "full_text_search": true},
901                "noteContents": {"type": "[]string", "filterable": true, "full_text_search": true},
902                "communityMemberHandle": {"type": "string", "filterable": true, "full_text_search": true},
903                "communityMemberIdentifier": {"type": "string", "filterable": true, "full_text_search": true},
904                "inboxFeedItemTitle": {"type": "string", "filterable": true, "full_text_search": true},
905                "isInboxFeedItemStarred": {"type": "bool", "filterable": true},
906                "isInboxFeedItemAnswered": {"type": "bool", "filterable": true},
907                "inboxFeedItemPreviewTimestamp": {"type": "datetime", "filterable": true},
908                "inboxFeedItemPublishTimestamp": {"type": "datetime", "filterable": true},
909                "inboxFeedItemAuthorInstagramFollowerCount": {"type": "int", "filterable": true},
910                "inboxFeedItemAuthorTikTokFollowerCount": {"type": "int", "filterable": true},
911                "attributes": {"type": "[]string", "filterable": true},
912                "threadId": {"type": "string", "filterable": true},
913                "vector": {"type": "[384]f32", "filterable": true, "ann": true}
914            })
915        );
916
917        let config = TurbopufferConfig {
918            base_url: "http://127.0.0.1:0".to_owned(),
919            namespace: None,
920            namespace_column: Some("workspaceId".to_owned()),
921            api_key: "tpuf_test_key".to_owned(),
922            distance_metric: Some("cosine_distance".to_owned()),
923            disable_backpressure: Some(true),
924            full_text_search_columns: Some("body,noteContents,communityMemberHandle,communityMemberIdentifier,inboxFeedItemTitle".to_owned()),
925            filterable_columns: Some("*".to_owned()),
926            r#type: "upsert".to_owned(),
927        };
928        let writer = TurbopufferSinkWriter::new(
929            config,
930            schema,
931            false,
932            0,
933            TurbopufferNamespace::Dynamic { index: 1 },
934            attribute_indices,
935            generated_schema.clone(),
936        )
937        .unwrap();
938        let vector =
939            VectorVal::from_text(&format!("[{}]", vec!["0.25"; 384].join(",")), 384).unwrap();
940        let row = OwnedRow::new(vec![
941            Some(ScalarImpl::Utf8("doc-1".into())),
942            Some(ScalarImpl::Utf8("workspace-1".into())),
943            Some(ScalarImpl::Utf8("item-1".into())),
944            Some(ScalarImpl::Utf8("body text".into())),
945            Some(ScalarImpl::List(ListValue::from_iter(["note a", "note b"]))),
946            Some(ScalarImpl::Utf8("@member".into())),
947            Some(ScalarImpl::Utf8("member-1".into())),
948            Some(ScalarImpl::Utf8("title".into())),
949            Some(ScalarImpl::Bool(true)),
950            Some(ScalarImpl::Bool(false)),
951            Some(ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(
952                1_781_582_706,
953                0,
954            ))),
955            Some(ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(
956                1_781_598_707,
957                0,
958            ))),
959            Some(ScalarImpl::Int64(12345)),
960            Some(ScalarImpl::Int64(67890)),
961            Some(ScalarImpl::List(ListValue::from_iter(["important", "vip"]))),
962            Some(ScalarImpl::Utf8("thread-1".into())),
963            Some(ScalarImpl::Vector(vector)),
964        ]);
965        let id = writer.id_for_row(&row).unwrap();
966        let upsert_row = writer.upsert_row(&row, id).unwrap();
967        let body = writer.request_body(vec![upsert_row], Vec::new());
968
969        assert_eq!(body["distance_metric"], json!("cosine_distance"));
970        assert_eq!(body["disable_backpressure"], json!(true));
971        assert_eq!(body["schema"], generated_schema);
972        assert_eq!(body["upsert_rows"][0]["id"], json!("doc-1"));
973        assert_eq!(body["upsert_rows"][0]["body"], json!("body text"));
974        assert_eq!(
975            body["upsert_rows"][0]["noteContents"],
976            json!(["note a", "note b"])
977        );
978        assert_eq!(
979            body["upsert_rows"][0]["isInboxFeedItemStarred"],
980            json!(true)
981        );
982        assert_eq!(
983            body["upsert_rows"][0]["isInboxFeedItemAnswered"],
984            json!(false)
985        );
986        assert_eq!(
987            body["upsert_rows"][0]["inboxFeedItemAuthorInstagramFollowerCount"],
988            json!(12345)
989        );
990        assert_eq!(
991            body["upsert_rows"][0]["inboxFeedItemPreviewTimestamp"],
992            json!("2026-06-16 04:05:06.000000")
993        );
994        assert_eq!(
995            body["upsert_rows"][0]["vector"].as_array().unwrap().len(),
996            384
997        );
998        assert_eq!(body["upsert_rows"][0]["vector"][0], json!(0.25));
999    }
1000
1001    #[cfg(not(madsim))]
1002    fn spawn_mock_http_server() -> (String, mpsc::Receiver<String>, thread::JoinHandle<()>) {
1003        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1004        let addr = listener.local_addr().unwrap();
1005        let (request_tx, request_rx) = mpsc::channel();
1006        let server_thread = thread::spawn(move || {
1007            let (mut stream, _) = listener.accept().unwrap();
1008            let mut buf = Vec::new();
1009            let header_end = loop {
1010                let mut tmp = [0; 1024];
1011                let read = stream.read(&mut tmp).unwrap();
1012                assert_ne!(read, 0);
1013                buf.extend_from_slice(&tmp[..read]);
1014                if let Some(header_end) = find_header_end(&buf) {
1015                    break header_end;
1016                }
1017            };
1018            let headers = String::from_utf8_lossy(&buf[..header_end]);
1019            let content_length = headers
1020                .lines()
1021                .find_map(|line| {
1022                    let (name, value) = line.split_once(':')?;
1023                    name.eq_ignore_ascii_case("content-length")
1024                        .then(|| value.trim().parse::<usize>().unwrap())
1025                })
1026                .unwrap();
1027            while buf.len() < header_end + 4 + content_length {
1028                let mut tmp = [0; 1024];
1029                let read = stream.read(&mut tmp).unwrap();
1030                assert_ne!(read, 0);
1031                buf.extend_from_slice(&tmp[..read]);
1032            }
1033            let request = String::from_utf8(buf[..header_end + 4 + content_length].to_vec())
1034                .expect("HTTP request should be utf8");
1035            request_tx.send(request.to_lowercase()).unwrap();
1036            stream
1037                .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
1038                .unwrap();
1039        });
1040        (format!("http://{}", addr), request_rx, server_thread)
1041    }
1042
1043    #[cfg(not(madsim))]
1044    fn find_header_end(buf: &[u8]) -> Option<usize> {
1045        buf.windows(4).position(|window| window == b"\r\n\r\n")
1046    }
1047}