risingwave_connector/sink/
doris_starrocks_connector.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use core::mem;
16use core::time::Duration;
17use std::collections::HashMap;
18use std::convert::Infallible;
19
20use anyhow::{Context, anyhow};
21use base64::Engine;
22use base64::engine::general_purpose;
23use bytes::{BufMut, Bytes, BytesMut};
24use futures::StreamExt;
25use reqwest::header::{HeaderName, HeaderValue};
26use reqwest::{Body, Client, Method, Request, RequestBuilder, Response, StatusCode, redirect};
27use tokio::sync::mpsc::UnboundedSender;
28use tokio::task::JoinHandle;
29use url::Url;
30
31use super::{Result, SinkError};
32
/// Initial capacity, in bytes, of the write buffer owned by each `InserterInner`.
const BUFFER_SIZE: usize = 64 * 1024;
/// Buffered byte count at which `InserterInner::write` forwards the buffer as one request-body
/// chunk (1 KiB below `BUFFER_SIZE`).
const MIN_CHUNK_SIZE: usize = BUFFER_SIZE - 1024;

/// Idle timeout for pooled connections of the `reqwest` clients built in this module.
pub(crate) const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// Loopback host name recognised when rewriting BE addresses taken from FE redirects.
const LOCALHOST: &str = "localhost";
/// Loopback IP recognised when rewriting BE addresses taken from FE redirects.
const LOCALHOST_IP: &str = "127.0.0.1";

/// Status strings treated as a successful Doris stream load (presumably compared against the
/// `Status` field of the load response by the caller — confirm there).
pub(crate) const DORIS_SUCCESS_STATUS: [&str; 2] = ["Success", "Publish Timeout"];
/// Status strings treated as a successful StarRocks stream load (presumably compared against
/// the response's status field by the caller — confirm there).
pub(crate) const STARROCKS_SUCCESS_STATUS: [&str; 1] = ["OK"];
/// Name of Doris' hidden delete-sign column; `HeaderBuilder::add_hidden_column` sends it in the
/// `hidden_columns` header.
pub(crate) const DORIS_DELETE_SIGN: &str = "__DORIS_DELETE_SIGN__";
/// Name of the StarRocks operation column (presumably used to mark upserts vs. deletes; not
/// referenced in this module — verify against callers).
pub(crate) const STARROCKS_DELETE_SIGN: &str = "__op";
43pub struct HeaderBuilder {
44    header: HashMap<String, String>,
45}
46impl Default for HeaderBuilder {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51impl HeaderBuilder {
52    pub fn new() -> Self {
53        Self {
54            header: HashMap::default(),
55        }
56    }
57
58    pub fn add_common_header(mut self) -> Self {
59        self.header
60            .insert("expect".to_owned(), "100-continue".to_owned());
61        self
62    }
63
64    /// The method is temporarily not in use, reserved for later use in 2PC.
65    /// Doris will generate a default, non-repeating label.
66    pub fn set_label(mut self, label: String) -> Self {
67        self.header.insert("label".to_owned(), label);
68        self
69    }
70
71    pub fn set_columns_name(mut self, columns_name: Vec<&str>) -> Self {
72        let columns_name_str = columns_name.join(",");
73        self.header.insert("columns".to_owned(), columns_name_str);
74        self
75    }
76
77    /// This method is only called during upsert operations.
78    pub fn add_hidden_column(mut self) -> Self {
79        self.header
80            .insert("hidden_columns".to_owned(), DORIS_DELETE_SIGN.to_owned());
81        self
82    }
83
84    /// The method is temporarily not in use, reserved for later use in 2PC.
85    /// Only use in Doris
86    pub fn enable_2_pc(mut self) -> Self {
87        self.header
88            .insert("two_phase_commit".to_owned(), "true".to_owned());
89        self
90    }
91
92    pub fn set_user_password(mut self, user: String, password: String) -> Self {
93        let auth = format!(
94            "Basic {}",
95            general_purpose::STANDARD.encode(format!("{}:{}", user, password))
96        );
97        self.header.insert("Authorization".to_owned(), auth);
98        self
99    }
100
101    /// The method is temporarily not in use, reserved for later use in 2PC.
102    /// Only use in Doris
103    pub fn set_txn_id(mut self, txn_id: i64) -> Self {
104        self.header
105            .insert("txn_operation".to_owned(), txn_id.to_string());
106        self
107    }
108
109    /// The method is temporarily not in use, reserved for later use in 2PC.
110    /// Only use in Doris
111    pub fn add_commit(mut self) -> Self {
112        self.header
113            .insert("txn_operation".to_owned(), "commit".to_owned());
114        self
115    }
116
117    /// The method is temporarily not in use, reserved for later use in 2PC.
118    /// Only use in Doris
119    pub fn add_abort(mut self) -> Self {
120        self.header
121            .insert("txn_operation".to_owned(), "abort".to_owned());
122        self
123    }
124
125    pub fn add_json_format(mut self) -> Self {
126        self.header.insert("format".to_owned(), "json".to_owned());
127        self
128    }
129
130    /// Only use in Doris
131    pub fn add_read_json_by_line(mut self) -> Self {
132        self.header
133            .insert("read_json_by_line".to_owned(), "true".to_owned());
134        self
135    }
136
137    /// Only use in Starrocks
138    pub fn add_strip_outer_array(mut self) -> Self {
139        self.header
140            .insert("strip_outer_array".to_owned(), "true".to_owned());
141        self
142    }
143
144    /// Only use in Starrocks
145    pub fn set_partial_update(mut self, partial_update: Option<String>) -> Self {
146        self.header.insert(
147            "partial_update".to_owned(),
148            partial_update.unwrap_or_else(|| "false".to_owned()),
149        );
150        self
151    }
152
153    /// Only use in Doris
154    pub fn set_partial_columns(mut self, partial_columns: Option<String>) -> Self {
155        self.header.insert(
156            "partial_columns".to_owned(),
157            partial_columns.unwrap_or_else(|| "false".to_owned()),
158        );
159        self
160    }
161
162    /// Only used in Starrocks Transaction API
163    pub fn set_db(mut self, db: String) -> Self {
164        self.header.insert("db".to_owned(), db);
165        self
166    }
167
168    /// Only used in Starrocks Transaction API
169    pub fn set_table(mut self, table: String) -> Self {
170        self.header.insert("table".to_owned(), table);
171        self
172    }
173
174    pub fn build(self) -> HashMap<String, String> {
175        self.header
176    }
177}
178
179/// Try getting BE url from a redirected response, returning `Ok(None)` indicates this request does
180/// not redirect.
181///
182/// The reason we handle the redirection manually is that if we let `reqwest` handle the redirection
183/// automatically, it will remove sensitive headers (such as Authorization) during the redirection,
184/// and there's no way to prevent this behavior.
185fn try_get_be_url(resp: &Response, fe_host: &str) -> Result<Option<Url>> {
186    match resp.status() {
187        StatusCode::TEMPORARY_REDIRECT => {
188            let be_url = resp
189                .headers()
190                .get("location")
191                .ok_or_else(|| {
192                    SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url in header",))
193                })?
194                .to_str()
195                .context("Can't get doris BE url in header")
196                .map_err(SinkError::DorisStarrocksConnect)?
197                .to_owned();
198
199            let mut parsed_be_url = Url::parse(&be_url)
200                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
201
202            if fe_host != LOCALHOST && fe_host != LOCALHOST_IP {
203                let be_host = parsed_be_url.host_str().ok_or_else(|| {
204                    SinkError::DorisStarrocksConnect(anyhow!("Can't get be host from url"))
205                })?;
206
207                if be_host == LOCALHOST || be_host == LOCALHOST_IP {
208                    // if be host is 127.0.0.1, we may can't connect to it directly,
209                    // so replace it with fe host
210                    parsed_be_url
211                        .set_host(Some(fe_host))
212                        .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
213                }
214            }
215            Ok(Some(parsed_be_url))
216        }
217        StatusCode::OK => {
218            // Some of the `StarRocks` transactional APIs will respond directly from FE. For example,
219            // the request to `/api/transaction/commit` endpoint does not seem to redirect to BE.
220            // In this case, the request should be treated as finished.
221            Ok(None)
222        }
223        _ => Err(SinkError::DorisStarrocksConnect(anyhow!(
224            "Can't get doris BE url",
225        ))),
226    }
227}
228
229pub struct InserterInnerBuilder {
230    url: String,
231    header: HashMap<String, String>,
232    #[expect(dead_code)]
233    sender: Option<Sender>,
234    fe_host: String,
235    stream_load_http_timeout: Duration,
236}
237impl InserterInnerBuilder {
238    pub fn new(
239        url: String,
240        db: String,
241        table: String,
242        header: HashMap<String, String>,
243        stream_load_http_timeout_ms: u64,
244    ) -> Result<Self> {
245        let fe_host = Url::parse(&url)
246            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
247            .host_str()
248            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
249            .to_owned();
250        let url = format!("{}/api/{}/{}/_stream_load", url, db, table);
251        let stream_load_http_timeout = Duration::from_millis(stream_load_http_timeout_ms);
252
253        Ok(Self {
254            url,
255            sender: None,
256            header,
257            fe_host,
258            stream_load_http_timeout,
259        })
260    }
261
262    fn build_request(&self, uri: String) -> Result<RequestBuilder> {
263        let client = Client::builder()
264            .pool_idle_timeout(POOL_IDLE_TIMEOUT)
265            .redirect(redirect::Policy::none()) // we handle redirect by ourselves
266            .build()
267            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
268
269        let mut builder = client.put(uri);
270        for (k, v) in &self.header {
271            builder = builder.header(k, v);
272        }
273        Ok(builder)
274    }
275
276    pub async fn build(&self) -> Result<InserterInner> {
277        let builder = self.build_request(self.url.clone())?;
278        let resp = builder
279            .send()
280            .await
281            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
282
283        let be_url = try_get_be_url(&resp, self.fe_host.as_str())?
284            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url",)))?;
285
286        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
287        let body = Body::wrap_stream(
288            tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
289        );
290        let builder = self.build_request(be_url.into())?.body(body);
291
292        let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
293            let response = builder
294                .send()
295                .await
296                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
297            let status = response.status();
298            let raw = response
299                .bytes()
300                .await
301                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
302                .into();
303
304            if status == StatusCode::OK {
305                Ok(raw)
306            } else {
307                let response_body = String::from_utf8(raw).map_err(|err| {
308                    SinkError::DorisStarrocksConnect(
309                        anyhow!(err).context("failed to parse response body"),
310                    )
311                })?;
312                Err(SinkError::DorisStarrocksConnect(anyhow!(
313                    "Failed connection {:?},{:?}",
314                    status,
315                    response_body
316                )))
317            }
318        });
319        Ok(InserterInner::new(
320            sender,
321            handle,
322            self.stream_load_http_timeout,
323        ))
324    }
325}
326
327type Sender = UnboundedSender<Bytes>;
328
329pub struct InserterInner {
330    sender: Option<Sender>,
331    join_handle: JoinHandle<Result<Vec<u8>>>,
332    buffer: BytesMut,
333    stream_load_http_timeout: Duration,
334}
335impl InserterInner {
336    pub fn new(
337        sender: Sender,
338        join_handle: JoinHandle<Result<Vec<u8>>>,
339        stream_load_http_timeout: Duration,
340    ) -> Self {
341        Self {
342            sender: Some(sender),
343            join_handle,
344            buffer: BytesMut::with_capacity(BUFFER_SIZE),
345            stream_load_http_timeout,
346        }
347    }
348
349    async fn send_chunk(&mut self) -> Result<()> {
350        if self.sender.is_none() {
351            return Ok(());
352        }
353
354        let chunk = mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE));
355
356        match self.sender.as_mut().unwrap().send(chunk.freeze()) {
357            Err(_e) => {
358                self.sender.take();
359                self.wait_handle().await?;
360
361                Err(SinkError::DorisStarrocksConnect(anyhow!("channel closed")))
362            }
363            _ => Ok(()),
364        }
365    }
366
367    pub async fn write(&mut self, data: Bytes) -> Result<()> {
368        self.buffer.put_slice(&data);
369        if self.buffer.len() >= MIN_CHUNK_SIZE {
370            self.send_chunk().await?;
371        }
372        Ok(())
373    }
374
375    async fn wait_handle(&mut self) -> Result<Vec<u8>> {
376        let res = match tokio::time::timeout(self.stream_load_http_timeout, &mut self.join_handle)
377            .await
378        {
379            Ok(res) => res.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))??,
380            Err(err) => return Err(SinkError::DorisStarrocksConnect(anyhow!(err))),
381        };
382        Ok(res)
383    }
384
385    pub async fn finish(mut self) -> Result<Vec<u8>> {
386        if !self.buffer.is_empty() {
387            self.send_chunk().await?;
388        }
389        self.sender = None;
390        self.wait_handle().await
391    }
392}
393
394enum StreamLoadResponse {
395    BeRequest(Request),
396    HttpResponse(Response),
397}
398
399/// Send the request and handle redirection if any.
400/// The reason we handle the redirection manually is that if we let `reqwest` handle the redirection
401/// automatically, it will remove sensitive headers (such as Authorization) during the redirection,
402/// and there's no way to prevent this behavior.
403/// Please note, the FE address that user specified might be a FE follower not the leader, in this case,
404/// the follower FE will redirect request to leader FE and then to BE.
405async fn send_stream_load_request(
406    client: Client,
407    mut request: Request,
408    fe_host: &str,
409) -> Result<StreamLoadResponse> {
410    // possible redirection paths:
411    // RW <-> follower FE -> leader FE -> BE
412    // RW <-> leader FE -> BE
413    // RW <-> leader FE
414    for _ in 0..2 {
415        let original_http_port = request.url().port();
416        let mut request_for_redirection = request
417            .try_clone()
418            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't clone request")))?;
419        let resp = client.execute(request).await.map_err(|err| {
420            SinkError::DorisStarrocksConnect(
421                anyhow!(err).context("sending stream load request failed"),
422            )
423        })?;
424        let be_url = try_get_be_url(&resp, fe_host)?;
425        match be_url {
426            Some(be_url) => {
427                // we used an unconventional method to detect if we are currently redirecting to FE leader, i.e.,
428                // by comparing the port of the redirected url with that of the original request, if they are same, we consider
429                // this is a FE address. Because in practice, no one would deploy their `StarRocks` cluster with the same
430                // http port for both FE and BE. However, this is a potentially problematic assumption,
431                // we may investigate a better way to do this. For example, we could use the `show backends` command to check
432                // if the host of the redirected url is in the list. However, `show backends` requires
433                // the system-level privilege, which could break the backward compatibility.
434                let redirected_port = be_url.port();
435                *request_for_redirection.url_mut() = be_url;
436                if redirected_port == original_http_port {
437                    // redirected to FE, continue another round.
438                    request = request_for_redirection;
439                } else {
440                    // we got BE address here
441                    return Ok(StreamLoadResponse::BeRequest(request_for_redirection));
442                }
443            }
444            None => return Ok(StreamLoadResponse::HttpResponse(resp)),
445        }
446    }
447    Err(SinkError::DorisStarrocksConnect(anyhow!(
448        "redirection occur more than twice when sending stream load request"
449    )))
450}
451
452pub struct MetaRequestSender {
453    client: Client,
454    request: Request,
455    fe_host: String,
456}
457
458impl MetaRequestSender {
459    pub fn new(client: Client, request: Request, fe_host: String) -> Self {
460        Self {
461            client,
462            request,
463            fe_host,
464        }
465    }
466
467    pub async fn send(self) -> Result<Bytes> {
468        match send_stream_load_request(self.client.clone(), self.request, &self.fe_host).await? {
469            StreamLoadResponse::BeRequest(be_request) => self
470                .client
471                .execute(be_request)
472                .await
473                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
474                .bytes()
475                .await
476                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
477            StreamLoadResponse::HttpResponse(resp) => resp
478                .bytes()
479                .await
480                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
481        }
482    }
483}
484
485pub struct StarrocksTxnRequestBuilder {
486    url_begin: String,
487    url_load: String,
488    url_prepare: String,
489    url_commit: String,
490    url_rollback: String,
491    header: HashMap<String, String>,
492    fe_host: String,
493    stream_load_http_timeout: Duration,
494    // The `reqwest` crate suggests us reuse the Client, and we don't need make it Arc, because it
495    // already uses an Arc internally.
496    client: Client,
497}
498
499impl StarrocksTxnRequestBuilder {
500    pub fn new(
501        url: String,
502        header: HashMap<String, String>,
503        stream_load_http_timeout_ms: u64,
504    ) -> Result<Self> {
505        let fe_host = Url::parse(&url)
506            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
507            .host_str()
508            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
509            .to_owned();
510
511        let url_begin = format!("{}/api/transaction/begin", url);
512        let url_load = format!("{}/api/transaction/load", url);
513        let url_prepare = format!("{}/api/transaction/prepare", url);
514        let url_commit = format!("{}/api/transaction/commit", url);
515        let url_rollback = format!("{}/api/transaction/rollback", url);
516
517        let stream_load_http_timeout = Duration::from_millis(stream_load_http_timeout_ms);
518
519        let client = Client::builder()
520            .pool_idle_timeout(POOL_IDLE_TIMEOUT)
521            .redirect(redirect::Policy::none())
522            .build()
523            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
524
525        Ok(Self {
526            url_begin,
527            url_load,
528            url_prepare,
529            url_commit,
530            url_rollback,
531            header,
532            fe_host,
533            stream_load_http_timeout,
534            client,
535        })
536    }
537
538    fn build_request(&self, uri: String, method: Method, label: String) -> Result<Request> {
539        let parsed_url =
540            Url::parse(&uri).map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
541        let mut request = Request::new(method, parsed_url);
542
543        if uri != self.url_load {
544            // Set timeout for non-load requests; load requests' timeout is controlled by `tokio::timeout`
545            *request.timeout_mut() = Some(self.stream_load_http_timeout);
546        }
547
548        let header = request.headers_mut();
549        for (k, v) in &self.header {
550            header.insert(
551                HeaderName::try_from(k)
552                    .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
553                HeaderValue::try_from(v)
554                    .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
555            );
556        }
557        header.insert(
558            "label",
559            HeaderValue::try_from(label)
560                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
561        );
562
563        Ok(request)
564    }
565
566    pub fn build_begin_request_sender(&self, label: String) -> Result<MetaRequestSender> {
567        let request = self.build_request(self.url_begin.clone(), Method::POST, label)?;
568        Ok(MetaRequestSender::new(
569            self.client.clone(),
570            request,
571            self.fe_host.clone(),
572        ))
573    }
574
575    pub fn build_prepare_request_sender(&self, label: String) -> Result<MetaRequestSender> {
576        let request = self.build_request(self.url_prepare.clone(), Method::POST, label)?;
577        Ok(MetaRequestSender::new(
578            self.client.clone(),
579            request,
580            self.fe_host.clone(),
581        ))
582    }
583
584    pub fn build_commit_request_sender(&self, label: String) -> Result<MetaRequestSender> {
585        let request = self.build_request(self.url_commit.clone(), Method::POST, label)?;
586        Ok(MetaRequestSender::new(
587            self.client.clone(),
588            request,
589            self.fe_host.clone(),
590        ))
591    }
592
593    pub fn build_rollback_request_sender(&self, label: String) -> Result<MetaRequestSender> {
594        let request = self.build_request(self.url_rollback.clone(), Method::POST, label)?;
595        Ok(MetaRequestSender::new(
596            self.client.clone(),
597            request,
598            self.fe_host.clone(),
599        ))
600    }
601
602    pub async fn build_txn_inserter(&self, label: String) -> Result<InserterInner> {
603        let request = self.build_request(self.url_load.clone(), Method::PUT, label.clone())?;
604        let mut be_request =
605            match send_stream_load_request(self.client.clone(), request, &self.fe_host).await? {
606                StreamLoadResponse::BeRequest(be_request) => be_request,
607                StreamLoadResponse::HttpResponse(resp) => {
608                    // If we get a response here, it should be from BE, so we extract the URL
609                    // and create a new request based on it.
610                    let url = resp.url().clone();
611                    self.build_request(url.into(), Method::PUT, label)?
612                }
613            };
614        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
615        let body = Body::wrap_stream(
616            tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
617        );
618        *be_request.body_mut() = Some(body);
619
620        let client = self.client.clone();
621        let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
622            let response = client
623                .execute(be_request)
624                .await
625                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
626
627            let status = response.status();
628            let raw = response
629                .bytes()
630                .await
631                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
632                .into();
633
634            if status == StatusCode::OK {
635                Ok(raw)
636            } else {
637                let response_body = String::from_utf8(raw).map_err(|err| {
638                    SinkError::DorisStarrocksConnect(
639                        anyhow!(err).context("failed to parse response body"),
640                    )
641                })?;
642                Err(SinkError::DorisStarrocksConnect(anyhow!(
643                    "Failed connection {:?},{:?}",
644                    status,
645                    response_body
646                )))
647            }
648        });
649        Ok(InserterInner::new(
650            sender,
651            handle,
652            self.stream_load_http_timeout,
653        ))
654    }
655}