1use core::mem;
16use core::time::Duration;
17use std::collections::HashMap;
18use std::convert::Infallible;
19
20use anyhow::{Context, anyhow};
21use base64::Engine;
22use base64::engine::general_purpose;
23use bytes::{BufMut, Bytes, BytesMut};
24use futures::StreamExt;
25use reqwest::header::{HeaderName, HeaderValue};
26use reqwest::{Body, Client, Method, Request, RequestBuilder, Response, StatusCode, redirect};
27use tokio::sync::mpsc::UnboundedSender;
28use tokio::task::JoinHandle;
29use url::Url;
30
31use super::{Result, SinkError};
32
/// Initial capacity (64 KiB) of the buffer that accumulates bytes before they are
/// handed to the streaming request body.
const BUFFER_SIZE: usize = 1 << 16;
/// Number of buffered bytes at which a chunk is flushed: 1 KiB below [`BUFFER_SIZE`].
const MIN_CHUNK_SIZE: usize = BUFFER_SIZE - (1 << 10);
/// Doris response statuses treated as success (presumably compared against the
/// response's `Status` field by callers outside this view — confirm).
pub(crate) const DORIS_SUCCESS_STATUS: [&str; 2] = ["Success", "Publish Timeout"];
/// StarRocks response statuses treated as success (same caveat as above).
pub(crate) const STARROCKS_SUCCESS_STATUS: [&str; 1] = ["OK"];
/// Hidden Doris column name, sent as the `hidden_columns` header by `HeaderBuilder`.
pub(crate) const DORIS_DELETE_SIGN: &str = "__DORIS_DELETE_SIGN__";
/// StarRocks counterpart of [`DORIS_DELETE_SIGN`] (assumed to be the op column — confirm).
pub(crate) const STARROCKS_DELETE_SIGN: &str = "__op";

/// Idle lifetime of pooled HTTP connections, passed to reqwest's `pool_idle_timeout`.
pub(crate) const POOL_IDLE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Loopback host names recognised when rewriting redirect targets.
const LOCALHOST: &str = "localhost";
const LOCALHOST_IP: &str = "127.0.0.1";
43pub struct HeaderBuilder {
44 header: HashMap<String, String>,
45}
46impl Default for HeaderBuilder {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51impl HeaderBuilder {
52 pub fn new() -> Self {
53 Self {
54 header: HashMap::default(),
55 }
56 }
57
58 pub fn add_common_header(mut self) -> Self {
59 self.header
60 .insert("expect".to_owned(), "100-continue".to_owned());
61 self
62 }
63
64 pub fn set_label(mut self, label: String) -> Self {
67 self.header.insert("label".to_owned(), label);
68 self
69 }
70
71 pub fn set_columns_name(mut self, columns_name: Vec<&str>) -> Self {
72 let columns_name_str = columns_name.join(",");
73 self.header.insert("columns".to_owned(), columns_name_str);
74 self
75 }
76
77 pub fn add_hidden_column(mut self) -> Self {
79 self.header
80 .insert("hidden_columns".to_owned(), DORIS_DELETE_SIGN.to_owned());
81 self
82 }
83
84 pub fn enable_2_pc(mut self) -> Self {
87 self.header
88 .insert("two_phase_commit".to_owned(), "true".to_owned());
89 self
90 }
91
92 pub fn set_user_password(mut self, user: String, password: String) -> Self {
93 let auth = format!(
94 "Basic {}",
95 general_purpose::STANDARD.encode(format!("{}:{}", user, password))
96 );
97 self.header.insert("Authorization".to_owned(), auth);
98 self
99 }
100
101 pub fn set_txn_id(mut self, txn_id: i64) -> Self {
104 self.header
105 .insert("txn_operation".to_owned(), txn_id.to_string());
106 self
107 }
108
109 pub fn add_commit(mut self) -> Self {
112 self.header
113 .insert("txn_operation".to_owned(), "commit".to_owned());
114 self
115 }
116
117 pub fn add_abort(mut self) -> Self {
120 self.header
121 .insert("txn_operation".to_owned(), "abort".to_owned());
122 self
123 }
124
125 pub fn add_json_format(mut self) -> Self {
126 self.header.insert("format".to_owned(), "json".to_owned());
127 self
128 }
129
130 pub fn add_read_json_by_line(mut self) -> Self {
132 self.header
133 .insert("read_json_by_line".to_owned(), "true".to_owned());
134 self
135 }
136
137 pub fn add_strip_outer_array(mut self) -> Self {
139 self.header
140 .insert("strip_outer_array".to_owned(), "true".to_owned());
141 self
142 }
143
144 pub fn set_partial_update(mut self, partial_update: Option<String>) -> Self {
146 self.header.insert(
147 "partial_update".to_owned(),
148 partial_update.unwrap_or_else(|| "false".to_owned()),
149 );
150 self
151 }
152
153 pub fn set_partial_columns(mut self, partial_columns: Option<String>) -> Self {
155 self.header.insert(
156 "partial_columns".to_owned(),
157 partial_columns.unwrap_or_else(|| "false".to_owned()),
158 );
159 self
160 }
161
162 pub fn set_db(mut self, db: String) -> Self {
164 self.header.insert("db".to_owned(), db);
165 self
166 }
167
168 pub fn set_table(mut self, table: String) -> Self {
170 self.header.insert("table".to_owned(), table);
171 self
172 }
173
174 pub fn build(self) -> HashMap<String, String> {
175 self.header
176 }
177}
178
179fn try_get_be_url(resp: &Response, fe_host: &str) -> Result<Option<Url>> {
186 match resp.status() {
187 StatusCode::TEMPORARY_REDIRECT => {
188 let be_url = resp
189 .headers()
190 .get("location")
191 .ok_or_else(|| {
192 SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url in header",))
193 })?
194 .to_str()
195 .context("Can't get doris BE url in header")
196 .map_err(SinkError::DorisStarrocksConnect)?
197 .to_owned();
198
199 let mut parsed_be_url = Url::parse(&be_url)
200 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
201
202 if fe_host != LOCALHOST && fe_host != LOCALHOST_IP {
203 let be_host = parsed_be_url.host_str().ok_or_else(|| {
204 SinkError::DorisStarrocksConnect(anyhow!("Can't get be host from url"))
205 })?;
206
207 if be_host == LOCALHOST || be_host == LOCALHOST_IP {
208 parsed_be_url
211 .set_host(Some(fe_host))
212 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
213 }
214 }
215 Ok(Some(parsed_be_url))
216 }
217 StatusCode::OK => {
218 Ok(None)
222 }
223 _ => Err(SinkError::DorisStarrocksConnect(anyhow!(
224 "Can't get doris BE url",
225 ))),
226 }
227}
228
229pub struct InserterInnerBuilder {
230 url: String,
231 header: HashMap<String, String>,
232 #[expect(dead_code)]
233 sender: Option<Sender>,
234 fe_host: String,
235 stream_load_http_timeout: Duration,
236}
237impl InserterInnerBuilder {
238 pub fn new(
239 url: String,
240 db: String,
241 table: String,
242 header: HashMap<String, String>,
243 stream_load_http_timeout_ms: u64,
244 ) -> Result<Self> {
245 let fe_host = Url::parse(&url)
246 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
247 .host_str()
248 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
249 .to_owned();
250 let url = format!("{}/api/{}/{}/_stream_load", url, db, table);
251 let stream_load_http_timeout = Duration::from_millis(stream_load_http_timeout_ms);
252
253 Ok(Self {
254 url,
255 sender: None,
256 header,
257 fe_host,
258 stream_load_http_timeout,
259 })
260 }
261
262 fn build_request(&self, uri: String) -> Result<RequestBuilder> {
263 let client = Client::builder()
264 .pool_idle_timeout(POOL_IDLE_TIMEOUT)
265 .redirect(redirect::Policy::none()) .build()
267 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
268
269 let mut builder = client.put(uri);
270 for (k, v) in &self.header {
271 builder = builder.header(k, v);
272 }
273 Ok(builder)
274 }
275
276 pub async fn build(&self) -> Result<InserterInner> {
277 let builder = self.build_request(self.url.clone())?;
278 let resp = builder
279 .send()
280 .await
281 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
282
283 let be_url = try_get_be_url(&resp, self.fe_host.as_str())?
284 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url",)))?;
285
286 let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
287 let body = Body::wrap_stream(
288 tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
289 );
290 let builder = self.build_request(be_url.into())?.body(body);
291
292 let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
293 let response = builder
294 .send()
295 .await
296 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
297 let status = response.status();
298 let raw = response
299 .bytes()
300 .await
301 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
302 .into();
303
304 if status == StatusCode::OK {
305 Ok(raw)
306 } else {
307 let response_body = String::from_utf8(raw).map_err(|err| {
308 SinkError::DorisStarrocksConnect(
309 anyhow!(err).context("failed to parse response body"),
310 )
311 })?;
312 Err(SinkError::DorisStarrocksConnect(anyhow!(
313 "Failed connection {:?},{:?}",
314 status,
315 response_body
316 )))
317 }
318 });
319 Ok(InserterInner::new(
320 sender,
321 handle,
322 self.stream_load_http_timeout,
323 ))
324 }
325}
326
327type Sender = UnboundedSender<Bytes>;
328
329pub struct InserterInner {
330 sender: Option<Sender>,
331 join_handle: JoinHandle<Result<Vec<u8>>>,
332 buffer: BytesMut,
333 stream_load_http_timeout: Duration,
334}
335impl InserterInner {
336 pub fn new(
337 sender: Sender,
338 join_handle: JoinHandle<Result<Vec<u8>>>,
339 stream_load_http_timeout: Duration,
340 ) -> Self {
341 Self {
342 sender: Some(sender),
343 join_handle,
344 buffer: BytesMut::with_capacity(BUFFER_SIZE),
345 stream_load_http_timeout,
346 }
347 }
348
349 async fn send_chunk(&mut self) -> Result<()> {
350 if self.sender.is_none() {
351 return Ok(());
352 }
353
354 let chunk = mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE));
355
356 match self.sender.as_mut().unwrap().send(chunk.freeze()) {
357 Err(_e) => {
358 self.sender.take();
359 self.wait_handle().await?;
360
361 Err(SinkError::DorisStarrocksConnect(anyhow!("channel closed")))
362 }
363 _ => Ok(()),
364 }
365 }
366
367 pub async fn write(&mut self, data: Bytes) -> Result<()> {
368 self.buffer.put_slice(&data);
369 if self.buffer.len() >= MIN_CHUNK_SIZE {
370 self.send_chunk().await?;
371 }
372 Ok(())
373 }
374
375 async fn wait_handle(&mut self) -> Result<Vec<u8>> {
376 let res = match tokio::time::timeout(self.stream_load_http_timeout, &mut self.join_handle)
377 .await
378 {
379 Ok(res) => res.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))??,
380 Err(err) => return Err(SinkError::DorisStarrocksConnect(anyhow!(err))),
381 };
382 Ok(res)
383 }
384
385 pub async fn finish(mut self) -> Result<Vec<u8>> {
386 if !self.buffer.is_empty() {
387 self.send_chunk().await?;
388 }
389 self.sender = None;
390 self.wait_handle().await
391 }
392}
393
394enum StreamLoadResponse {
395 BeRequest(Request),
396 HttpResponse(Response),
397}
398
399async fn send_stream_load_request(
406 client: Client,
407 mut request: Request,
408 fe_host: &str,
409) -> Result<StreamLoadResponse> {
410 for _ in 0..2 {
415 let original_http_port = request.url().port();
416 let mut request_for_redirection = request
417 .try_clone()
418 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't clone request")))?;
419 let resp = client.execute(request).await.map_err(|err| {
420 SinkError::DorisStarrocksConnect(
421 anyhow!(err).context("sending stream load request failed"),
422 )
423 })?;
424 let be_url = try_get_be_url(&resp, fe_host)?;
425 match be_url {
426 Some(be_url) => {
427 let redirected_port = be_url.port();
435 *request_for_redirection.url_mut() = be_url;
436 if redirected_port == original_http_port {
437 request = request_for_redirection;
439 } else {
440 return Ok(StreamLoadResponse::BeRequest(request_for_redirection));
442 }
443 }
444 None => return Ok(StreamLoadResponse::HttpResponse(resp)),
445 }
446 }
447 Err(SinkError::DorisStarrocksConnect(anyhow!(
448 "redirection occur more than twice when sending stream load request"
449 )))
450}
451
452pub struct MetaRequestSender {
453 client: Client,
454 request: Request,
455 fe_host: String,
456}
457
458impl MetaRequestSender {
459 pub fn new(client: Client, request: Request, fe_host: String) -> Self {
460 Self {
461 client,
462 request,
463 fe_host,
464 }
465 }
466
467 pub async fn send(self) -> Result<Bytes> {
468 match send_stream_load_request(self.client.clone(), self.request, &self.fe_host).await? {
469 StreamLoadResponse::BeRequest(be_request) => self
470 .client
471 .execute(be_request)
472 .await
473 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
474 .bytes()
475 .await
476 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
477 StreamLoadResponse::HttpResponse(resp) => resp
478 .bytes()
479 .await
480 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
481 }
482 }
483}
484
485pub struct StarrocksTxnRequestBuilder {
486 url_begin: String,
487 url_load: String,
488 url_prepare: String,
489 url_commit: String,
490 url_rollback: String,
491 header: HashMap<String, String>,
492 fe_host: String,
493 stream_load_http_timeout: Duration,
494 client: Client,
497}
498
499impl StarrocksTxnRequestBuilder {
500 pub fn new(
501 url: String,
502 header: HashMap<String, String>,
503 stream_load_http_timeout_ms: u64,
504 ) -> Result<Self> {
505 let fe_host = Url::parse(&url)
506 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
507 .host_str()
508 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
509 .to_owned();
510
511 let url_begin = format!("{}/api/transaction/begin", url);
512 let url_load = format!("{}/api/transaction/load", url);
513 let url_prepare = format!("{}/api/transaction/prepare", url);
514 let url_commit = format!("{}/api/transaction/commit", url);
515 let url_rollback = format!("{}/api/transaction/rollback", url);
516
517 let stream_load_http_timeout = Duration::from_millis(stream_load_http_timeout_ms);
518
519 let client = Client::builder()
520 .pool_idle_timeout(POOL_IDLE_TIMEOUT)
521 .redirect(redirect::Policy::none())
522 .build()
523 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
524
525 Ok(Self {
526 url_begin,
527 url_load,
528 url_prepare,
529 url_commit,
530 url_rollback,
531 header,
532 fe_host,
533 stream_load_http_timeout,
534 client,
535 })
536 }
537
538 fn build_request(&self, uri: String, method: Method, label: String) -> Result<Request> {
539 let parsed_url =
540 Url::parse(&uri).map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
541 let mut request = Request::new(method, parsed_url);
542
543 if uri != self.url_load {
544 *request.timeout_mut() = Some(self.stream_load_http_timeout);
546 }
547
548 let header = request.headers_mut();
549 for (k, v) in &self.header {
550 header.insert(
551 HeaderName::try_from(k)
552 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
553 HeaderValue::try_from(v)
554 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
555 );
556 }
557 header.insert(
558 "label",
559 HeaderValue::try_from(label)
560 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
561 );
562
563 Ok(request)
564 }
565
566 pub fn build_begin_request_sender(&self, label: String) -> Result<MetaRequestSender> {
567 let request = self.build_request(self.url_begin.clone(), Method::POST, label)?;
568 Ok(MetaRequestSender::new(
569 self.client.clone(),
570 request,
571 self.fe_host.clone(),
572 ))
573 }
574
575 pub fn build_prepare_request_sender(&self, label: String) -> Result<MetaRequestSender> {
576 let request = self.build_request(self.url_prepare.clone(), Method::POST, label)?;
577 Ok(MetaRequestSender::new(
578 self.client.clone(),
579 request,
580 self.fe_host.clone(),
581 ))
582 }
583
584 pub fn build_commit_request_sender(&self, label: String) -> Result<MetaRequestSender> {
585 let request = self.build_request(self.url_commit.clone(), Method::POST, label)?;
586 Ok(MetaRequestSender::new(
587 self.client.clone(),
588 request,
589 self.fe_host.clone(),
590 ))
591 }
592
593 pub fn build_rollback_request_sender(&self, label: String) -> Result<MetaRequestSender> {
594 let request = self.build_request(self.url_rollback.clone(), Method::POST, label)?;
595 Ok(MetaRequestSender::new(
596 self.client.clone(),
597 request,
598 self.fe_host.clone(),
599 ))
600 }
601
602 pub async fn build_txn_inserter(&self, label: String) -> Result<InserterInner> {
603 let request = self.build_request(self.url_load.clone(), Method::PUT, label.clone())?;
604 let mut be_request =
605 match send_stream_load_request(self.client.clone(), request, &self.fe_host).await? {
606 StreamLoadResponse::BeRequest(be_request) => be_request,
607 StreamLoadResponse::HttpResponse(resp) => {
608 let url = resp.url().clone();
611 self.build_request(url.into(), Method::PUT, label)?
612 }
613 };
614 let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
615 let body = Body::wrap_stream(
616 tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
617 );
618 *be_request.body_mut() = Some(body);
619
620 let client = self.client.clone();
621 let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
622 let response = client
623 .execute(be_request)
624 .await
625 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
626
627 let status = response.status();
628 let raw = response
629 .bytes()
630 .await
631 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
632 .into();
633
634 if status == StatusCode::OK {
635 Ok(raw)
636 } else {
637 let response_body = String::from_utf8(raw).map_err(|err| {
638 SinkError::DorisStarrocksConnect(
639 anyhow!(err).context("failed to parse response body"),
640 )
641 })?;
642 Err(SinkError::DorisStarrocksConnect(anyhow!(
643 "Failed connection {:?},{:?}",
644 status,
645 response_body
646 )))
647 }
648 });
649 Ok(InserterInner::new(
650 sender,
651 handle,
652 self.stream_load_http_timeout,
653 ))
654 }
655}