risingwave_connector/sink/doris_starrocks_connector.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::mem;
16use core::time::Duration;
17use std::collections::HashMap;
18use std::convert::Infallible;
19
20use anyhow::{Context, anyhow};
21use base64::Engine;
22use base64::engine::general_purpose;
23use bytes::{BufMut, Bytes, BytesMut};
24use futures::StreamExt;
25use reqwest::header::{HeaderName, HeaderValue};
26use reqwest::{Body, Client, Method, Request, RequestBuilder, Response, StatusCode, redirect};
27use tokio::sync::mpsc::UnboundedSender;
28use tokio::task::JoinHandle;
29use url::Url;
30
31use super::{Result, SinkError};
32
/// Capacity, in bytes, allocated for each buffer that accumulates data before it is pushed to
/// the stream-load body channel.
const BUFFER_SIZE: usize = 1 << 16;
/// A buffered chunk is flushed to the body channel as soon as it holds at least this many bytes
/// (1 KiB below [`BUFFER_SIZE`]).
const MIN_CHUNK_SIZE: usize = BUFFER_SIZE - (1 << 10);
/// Status strings treated as success for Doris loads (presumably compared against the `Status`
/// field of the stream-load response by callers — confirm).
pub(crate) const DORIS_SUCCESS_STATUS: [&str; 2] = ["Success", "Publish Timeout"];
/// Status strings treated as success for StarRocks loads (presumably compared against the
/// `Status` field of the response by callers — confirm).
pub(crate) const STARROCKS_SUCCESS_STATUS: [&str; 1] = ["OK"];
/// Name of the Doris hidden delete-sign column, sent in the `hidden_columns` header.
pub(crate) const DORIS_DELETE_SIGN: &str = "__DORIS_DELETE_SIGN__";
/// Name of the StarRocks operation column (presumably used to mark upserts/deletes — confirm).
pub(crate) const STARROCKS_DELETE_SIGN: &str = "__op";

/// How long `InserterInner` created by `InserterInnerBuilder` waits for the stream-load task to
/// finish. (The name's spelling is kept as-is for compatibility.)
const WAIT_HANDDLE_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Idle timeout for pooled HTTP connections of the `reqwest` clients built in this module.
pub(crate) const POOL_IDLE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Loopback host name; a BE redirect pointing here is rewritten to the FE host.
const LOCALHOST: &str = "localhost";
/// Loopback IPv4 address; a BE redirect pointing here is rewritten to the FE host.
const LOCALHOST_IP: &str = "127.0.0.1";
44pub struct HeaderBuilder {
45    header: HashMap<String, String>,
46}
47impl Default for HeaderBuilder {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52impl HeaderBuilder {
53    pub fn new() -> Self {
54        Self {
55            header: HashMap::default(),
56        }
57    }
58
59    pub fn add_common_header(mut self) -> Self {
60        self.header
61            .insert("expect".to_owned(), "100-continue".to_owned());
62        self
63    }
64
65    /// The method is temporarily not in use, reserved for later use in 2PC.
66    /// Doris will generate a default, non-repeating label.
67    pub fn set_label(mut self, label: String) -> Self {
68        self.header.insert("label".to_owned(), label);
69        self
70    }
71
72    pub fn set_columns_name(mut self, columns_name: Vec<&str>) -> Self {
73        let columns_name_str = columns_name.join(",");
74        self.header.insert("columns".to_owned(), columns_name_str);
75        self
76    }
77
78    /// This method is only called during upsert operations.
79    pub fn add_hidden_column(mut self) -> Self {
80        self.header
81            .insert("hidden_columns".to_owned(), DORIS_DELETE_SIGN.to_owned());
82        self
83    }
84
85    /// The method is temporarily not in use, reserved for later use in 2PC.
86    /// Only use in Doris
87    pub fn enable_2_pc(mut self) -> Self {
88        self.header
89            .insert("two_phase_commit".to_owned(), "true".to_owned());
90        self
91    }
92
93    pub fn set_user_password(mut self, user: String, password: String) -> Self {
94        let auth = format!(
95            "Basic {}",
96            general_purpose::STANDARD.encode(format!("{}:{}", user, password))
97        );
98        self.header.insert("Authorization".to_owned(), auth);
99        self
100    }
101
102    /// The method is temporarily not in use, reserved for later use in 2PC.
103    /// Only use in Doris
104    pub fn set_txn_id(mut self, txn_id: i64) -> Self {
105        self.header
106            .insert("txn_operation".to_owned(), txn_id.to_string());
107        self
108    }
109
110    /// The method is temporarily not in use, reserved for later use in 2PC.
111    /// Only use in Doris
112    pub fn add_commit(mut self) -> Self {
113        self.header
114            .insert("txn_operation".to_owned(), "commit".to_owned());
115        self
116    }
117
118    /// The method is temporarily not in use, reserved for later use in 2PC.
119    /// Only use in Doris
120    pub fn add_abort(mut self) -> Self {
121        self.header
122            .insert("txn_operation".to_owned(), "abort".to_owned());
123        self
124    }
125
126    pub fn add_json_format(mut self) -> Self {
127        self.header.insert("format".to_owned(), "json".to_owned());
128        self
129    }
130
131    /// Only use in Doris
132    pub fn add_read_json_by_line(mut self) -> Self {
133        self.header
134            .insert("read_json_by_line".to_owned(), "true".to_owned());
135        self
136    }
137
138    /// Only use in Starrocks
139    pub fn add_strip_outer_array(mut self) -> Self {
140        self.header
141            .insert("strip_outer_array".to_owned(), "true".to_owned());
142        self
143    }
144
145    /// Only use in Starrocks
146    pub fn set_partial_update(mut self, partial_update: Option<String>) -> Self {
147        self.header.insert(
148            "partial_update".to_owned(),
149            partial_update.unwrap_or_else(|| "false".to_owned()),
150        );
151        self
152    }
153
154    /// Only use in Doris
155    pub fn set_partial_columns(mut self, partial_columns: Option<String>) -> Self {
156        self.header.insert(
157            "partial_columns".to_owned(),
158            partial_columns.unwrap_or_else(|| "false".to_owned()),
159        );
160        self
161    }
162
163    /// Only used in Starrocks Transaction API
164    pub fn set_db(mut self, db: String) -> Self {
165        self.header.insert("db".to_owned(), db);
166        self
167    }
168
169    /// Only used in Starrocks Transaction API
170    pub fn set_table(mut self, table: String) -> Self {
171        self.header.insert("table".to_owned(), table);
172        self
173    }
174
175    pub fn build(self) -> HashMap<String, String> {
176        self.header
177    }
178}
179
180/// Try getting BE url from a redirected response, returning `Ok(None)` indicates this request does
181/// not redirect.
182///
183/// The reason we handle the redirection manually is that if we let `reqwest` handle the redirection
184/// automatically, it will remove sensitive headers (such as Authorization) during the redirection,
185/// and there's no way to prevent this behavior.
186fn try_get_be_url(resp: &Response, fe_host: &str) -> Result<Option<Url>> {
187    match resp.status() {
188        StatusCode::TEMPORARY_REDIRECT => {
189            let be_url = resp
190                .headers()
191                .get("location")
192                .ok_or_else(|| {
193                    SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url in header",))
194                })?
195                .to_str()
196                .context("Can't get doris BE url in header")
197                .map_err(SinkError::DorisStarrocksConnect)?
198                .to_owned();
199
200            let mut parsed_be_url = Url::parse(&be_url)
201                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
202
203            if fe_host != LOCALHOST && fe_host != LOCALHOST_IP {
204                let be_host = parsed_be_url.host_str().ok_or_else(|| {
205                    SinkError::DorisStarrocksConnect(anyhow!("Can't get be host from url"))
206                })?;
207
208                if be_host == LOCALHOST || be_host == LOCALHOST_IP {
209                    // if be host is 127.0.0.1, we may can't connect to it directly,
210                    // so replace it with fe host
211                    parsed_be_url
212                        .set_host(Some(fe_host))
213                        .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
214                }
215            }
216            Ok(Some(parsed_be_url))
217        }
218        StatusCode::OK => {
219            // Some of the `StarRocks` transactional APIs will respond directly from FE. For example,
220            // the request to `/api/transaction/commit` endpoint does not seem to redirect to BE.
221            // In this case, the request should be treated as finished.
222            Ok(None)
223        }
224        _ => Err(SinkError::DorisStarrocksConnect(anyhow!(
225            "Can't get doris BE url",
226        ))),
227    }
228}
229
230pub struct InserterInnerBuilder {
231    url: String,
232    header: HashMap<String, String>,
233    #[expect(dead_code)]
234    sender: Option<Sender>,
235    fe_host: String,
236}
237impl InserterInnerBuilder {
238    pub fn new(
239        url: String,
240        db: String,
241        table: String,
242        header: HashMap<String, String>,
243    ) -> Result<Self> {
244        let fe_host = Url::parse(&url)
245            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
246            .host_str()
247            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
248            .to_owned();
249        let url = format!("{}/api/{}/{}/_stream_load", url, db, table);
250
251        Ok(Self {
252            url,
253            sender: None,
254            header,
255            fe_host,
256        })
257    }
258
259    fn build_request(&self, uri: String) -> Result<RequestBuilder> {
260        let client = Client::builder()
261            .pool_idle_timeout(POOL_IDLE_TIMEOUT)
262            .redirect(redirect::Policy::none()) // we handle redirect by ourselves
263            .build()
264            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
265
266        let mut builder = client.put(uri);
267        for (k, v) in &self.header {
268            builder = builder.header(k, v);
269        }
270        Ok(builder)
271    }
272
273    pub async fn build(&self) -> Result<InserterInner> {
274        let builder = self.build_request(self.url.clone())?;
275        let resp = builder
276            .send()
277            .await
278            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
279
280        let be_url = try_get_be_url(&resp, self.fe_host.as_str())?
281            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url",)))?;
282
283        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
284        let body = Body::wrap_stream(
285            tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
286        );
287        let builder = self.build_request(be_url.into())?.body(body);
288
289        let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
290            let response = builder
291                .send()
292                .await
293                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
294            let status = response.status();
295            let raw = response
296                .bytes()
297                .await
298                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
299                .into();
300
301            if status == StatusCode::OK {
302                Ok(raw)
303            } else {
304                let response_body = String::from_utf8(raw).map_err(|err| {
305                    SinkError::DorisStarrocksConnect(
306                        anyhow!(err).context("failed to parse response body"),
307                    )
308                })?;
309                Err(SinkError::DorisStarrocksConnect(anyhow!(
310                    "Failed connection {:?},{:?}",
311                    status,
312                    response_body
313                )))
314            }
315        });
316        Ok(InserterInner::new(sender, handle, WAIT_HANDDLE_TIMEOUT))
317    }
318}
319
320type Sender = UnboundedSender<Bytes>;
321
322pub struct InserterInner {
323    sender: Option<Sender>,
324    join_handle: JoinHandle<Result<Vec<u8>>>,
325    buffer: BytesMut,
326    stream_load_http_timeout: Duration,
327}
328impl InserterInner {
329    pub fn new(
330        sender: Sender,
331        join_handle: JoinHandle<Result<Vec<u8>>>,
332        stream_load_http_timeout: Duration,
333    ) -> Self {
334        Self {
335            sender: Some(sender),
336            join_handle,
337            buffer: BytesMut::with_capacity(BUFFER_SIZE),
338            stream_load_http_timeout,
339        }
340    }
341
342    async fn send_chunk(&mut self) -> Result<()> {
343        if self.sender.is_none() {
344            return Ok(());
345        }
346
347        let chunk = mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE));
348
349        match self.sender.as_mut().unwrap().send(chunk.freeze()) {
350            Err(_e) => {
351                self.sender.take();
352                self.wait_handle().await?;
353
354                Err(SinkError::DorisStarrocksConnect(anyhow!("channel closed")))
355            }
356            _ => Ok(()),
357        }
358    }
359
360    pub async fn write(&mut self, data: Bytes) -> Result<()> {
361        self.buffer.put_slice(&data);
362        if self.buffer.len() >= MIN_CHUNK_SIZE {
363            self.send_chunk().await?;
364        }
365        Ok(())
366    }
367
368    async fn wait_handle(&mut self) -> Result<Vec<u8>> {
369        let res = match tokio::time::timeout(self.stream_load_http_timeout, &mut self.join_handle)
370            .await
371        {
372            Ok(res) => res.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))??,
373            Err(err) => return Err(SinkError::DorisStarrocksConnect(anyhow!(err))),
374        };
375        Ok(res)
376    }
377
378    pub async fn finish(mut self) -> Result<Vec<u8>> {
379        if !self.buffer.is_empty() {
380            self.send_chunk().await?;
381        }
382        self.sender = None;
383        self.wait_handle().await
384    }
385}
386
387enum StreamLoadResponse {
388    BeRequest(Request),
389    HttpResponse(Response),
390}
391
392/// Send the request and handle redirection if any.
393/// The reason we handle the redirection manually is that if we let `reqwest` handle the redirection
394/// automatically, it will remove sensitive headers (such as Authorization) during the redirection,
395/// and there's no way to prevent this behavior.
396/// Please note, the FE address that user specified might be a FE follower not the leader, in this case,
397/// the follower FE will redirect request to leader FE and then to BE.
398async fn send_stream_load_request(
399    client: Client,
400    mut request: Request,
401    fe_host: &str,
402) -> Result<StreamLoadResponse> {
403    // possible redirection paths:
404    // RW <-> follower FE -> leader FE -> BE
405    // RW <-> leader FE -> BE
406    // RW <-> leader FE
407    for _ in 0..2 {
408        let original_http_port = request.url().port();
409        let mut request_for_redirection = request
410            .try_clone()
411            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't clone request")))?;
412        let resp = client.execute(request).await.map_err(|err| {
413            SinkError::DorisStarrocksConnect(
414                anyhow!(err).context("sending stream load request failed"),
415            )
416        })?;
417        let be_url = try_get_be_url(&resp, fe_host)?;
418        match be_url {
419            Some(be_url) => {
420                // we used an unconventional method to detect if we are currently redirecting to FE leader, i.e.,
421                // by comparing the port of the redirected url with that of the original request, if they are same, we consider
422                // this is a FE address. Because in practice, no one would deploy their `StarRocks` cluster with the same
423                // http port for both FE and BE. However, this is a potentially problematic assumption,
424                // we may investigate a better way to do this. For example, we could use the `show backends` command to check
425                // if the host of the redirected url is in the list. However, `show backends` requires
426                // the system-level privilege, which could break the backward compatibility.
427                let redirected_port = be_url.port();
428                *request_for_redirection.url_mut() = be_url;
429                if redirected_port == original_http_port {
430                    // redirected to FE, continue another round.
431                    request = request_for_redirection;
432                } else {
433                    // we got BE address here
434                    return Ok(StreamLoadResponse::BeRequest(request_for_redirection));
435                }
436            }
437            None => return Ok(StreamLoadResponse::HttpResponse(resp)),
438        }
439    }
440    Err(SinkError::DorisStarrocksConnect(anyhow!(
441        "redirection occur more than twice when sending stream load request"
442    )))
443}
444
445pub struct MetaRequestSender {
446    client: Client,
447    request: Request,
448    fe_host: String,
449}
450
451impl MetaRequestSender {
452    pub fn new(client: Client, request: Request, fe_host: String) -> Self {
453        Self {
454            client,
455            request,
456            fe_host,
457        }
458    }
459
460    pub async fn send(self) -> Result<Bytes> {
461        match send_stream_load_request(self.client.clone(), self.request, &self.fe_host).await? {
462            StreamLoadResponse::BeRequest(be_request) => self
463                .client
464                .execute(be_request)
465                .await
466                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
467                .bytes()
468                .await
469                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
470            StreamLoadResponse::HttpResponse(resp) => resp
471                .bytes()
472                .await
473                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
474        }
475    }
476}
477
478pub struct StarrocksTxnRequestBuilder {
479    url_begin: String,
480    url_load: String,
481    url_prepare: String,
482    url_commit: String,
483    url_rollback: String,
484    header: HashMap<String, String>,
485    fe_host: String,
486    stream_load_http_timeout: Duration,
487    // The `reqwest` crate suggests us reuse the Client, and we don't need make it Arc, because it
488    // already uses an Arc internally.
489    client: Client,
490}
491
492impl StarrocksTxnRequestBuilder {
493    pub fn new(
494        url: String,
495        header: HashMap<String, String>,
496        stream_load_http_timeout_ms: u64,
497    ) -> Result<Self> {
498        let fe_host = Url::parse(&url)
499            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
500            .host_str()
501            .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
502            .to_owned();
503
504        let url_begin = format!("{}/api/transaction/begin", url);
505        let url_load = format!("{}/api/transaction/load", url);
506        let url_prepare = format!("{}/api/transaction/prepare", url);
507        let url_commit = format!("{}/api/transaction/commit", url);
508        let url_rollback = format!("{}/api/transaction/rollback", url);
509
510        let stream_load_http_timeout = Duration::from_millis(stream_load_http_timeout_ms);
511
512        let client = Client::builder()
513            .pool_idle_timeout(POOL_IDLE_TIMEOUT)
514            .redirect(redirect::Policy::none())
515            .build()
516            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
517
518        Ok(Self {
519            url_begin,
520            url_load,
521            url_prepare,
522            url_commit,
523            url_rollback,
524            header,
525            fe_host,
526            stream_load_http_timeout,
527            client,
528        })
529    }
530
531    fn build_request(&self, uri: String, method: Method, label: String) -> Result<Request> {
532        let parsed_url =
533            Url::parse(&uri).map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
534        let mut request = Request::new(method, parsed_url);
535
536        if uri != self.url_load {
537            // Set timeout for non-load requests; load requests' timeout is controlled by `tokio::timeout`
538            *request.timeout_mut() = Some(self.stream_load_http_timeout);
539        }
540
541        let header = request.headers_mut();
542        for (k, v) in &self.header {
543            header.insert(
544                HeaderName::try_from(k)
545                    .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
546                HeaderValue::try_from(v)
547                    .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
548            );
549        }
550        header.insert(
551            "label",
552            HeaderValue::try_from(label)
553                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
554        );
555
556        Ok(request)
557    }
558
559    pub fn build_begin_request_sender(&self, label: String) -> Result<MetaRequestSender> {
560        let request = self.build_request(self.url_begin.clone(), Method::POST, label)?;
561        Ok(MetaRequestSender::new(
562            self.client.clone(),
563            request,
564            self.fe_host.clone(),
565        ))
566    }
567
568    pub fn build_prepare_request_sender(&self, label: String) -> Result<MetaRequestSender> {
569        let request = self.build_request(self.url_prepare.clone(), Method::POST, label)?;
570        Ok(MetaRequestSender::new(
571            self.client.clone(),
572            request,
573            self.fe_host.clone(),
574        ))
575    }
576
577    pub fn build_commit_request_sender(&self, label: String) -> Result<MetaRequestSender> {
578        let request = self.build_request(self.url_commit.clone(), Method::POST, label)?;
579        Ok(MetaRequestSender::new(
580            self.client.clone(),
581            request,
582            self.fe_host.clone(),
583        ))
584    }
585
586    pub fn build_rollback_request_sender(&self, label: String) -> Result<MetaRequestSender> {
587        let request = self.build_request(self.url_rollback.clone(), Method::POST, label)?;
588        Ok(MetaRequestSender::new(
589            self.client.clone(),
590            request,
591            self.fe_host.clone(),
592        ))
593    }
594
595    pub async fn build_txn_inserter(&self, label: String) -> Result<InserterInner> {
596        let request = self.build_request(self.url_load.clone(), Method::PUT, label.clone())?;
597        let mut be_request =
598            match send_stream_load_request(self.client.clone(), request, &self.fe_host).await? {
599                StreamLoadResponse::BeRequest(be_request) => be_request,
600                StreamLoadResponse::HttpResponse(resp) => {
601                    // If we get a response here, it should be from BE, so we extract the URL
602                    // and create a new request based on it.
603                    let url = resp.url().clone();
604                    self.build_request(url.into(), Method::PUT, label)?
605                }
606            };
607        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
608        let body = Body::wrap_stream(
609            tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
610        );
611        *be_request.body_mut() = Some(body);
612
613        let client = self.client.clone();
614        let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
615            let response = client
616                .execute(be_request)
617                .await
618                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
619
620            let status = response.status();
621            let raw = response
622                .bytes()
623                .await
624                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
625                .into();
626
627            if status == StatusCode::OK {
628                Ok(raw)
629            } else {
630                let response_body = String::from_utf8(raw).map_err(|err| {
631                    SinkError::DorisStarrocksConnect(
632                        anyhow!(err).context("failed to parse response body"),
633                    )
634                })?;
635                Err(SinkError::DorisStarrocksConnect(anyhow!(
636                    "Failed connection {:?},{:?}",
637                    status,
638                    response_body
639                )))
640            }
641        });
642        Ok(InserterInner::new(
643            sender,
644            handle,
645            self.stream_load_http_timeout,
646        ))
647    }
648}