1use core::mem;
16use core::time::Duration;
17use std::collections::HashMap;
18use std::convert::Infallible;
19
20use anyhow::{Context, anyhow};
21use base64::Engine;
22use base64::engine::general_purpose;
23use bytes::{BufMut, Bytes, BytesMut};
24use futures::StreamExt;
25use reqwest::header::{HeaderName, HeaderValue};
26use reqwest::{Body, Client, Method, Request, RequestBuilder, Response, StatusCode, redirect};
27use tokio::sync::mpsc::UnboundedSender;
28use tokio::task::JoinHandle;
29use url::Url;
30
31use super::{Result, SinkError};
32
/// Capacity, in bytes, that `InserterInner` allocates for its write buffer.
const BUFFER_SIZE: usize = 64 * 1024;
/// Number of buffered bytes at which `InserterInner::write` hands the buffer over to the
/// request body; 1 KiB below `BUFFER_SIZE`.
const MIN_CHUNK_SIZE: usize = BUFFER_SIZE - 1024;
// NOTE(review): not referenced in the visible part of this file — presumably the status strings
// that callers treat as a successful Doris stream load; confirm against the users.
pub(crate) const DORIS_SUCCESS_STATUS: [&str; 2] = ["Success", "Publish Timeout"];
// NOTE(review): not referenced in the visible part of this file — presumably the StarRocks
// counterpart of `DORIS_SUCCESS_STATUS`; confirm against the users.
pub(crate) const STARROCKS_SUCCESS_STATUS: [&str; 1] = ["OK"];
/// Column name that `HeaderBuilder::add_hidden_column` sends in the `hidden_columns` header.
pub(crate) const DORIS_DELETE_SIGN: &str = "__DORIS_DELETE_SIGN__";
// NOTE(review): not referenced in the visible part of this file — presumably the StarRocks
// operation column used to mark deletes; confirm against the users.
pub(crate) const STARROCKS_DELETE_SIGN: &str = "__op";

/// Timeout that `InserterInnerBuilder::build` gives to the `InserterInner` it creates; it bounds
/// each wait on the spawned request task. (The identifier is spelled "HANDDLE" in this file.)
const WAIT_HANDDLE_TIMEOUT: Duration = Duration::from_secs(10);
/// Value passed to `pool_idle_timeout` for the HTTP clients built in this file.
pub(crate) const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// Host names that `try_get_be_url` treats as referring to the local machine.
const LOCALHOST: &str = "localhost";
const LOCALHOST_IP: &str = "127.0.0.1";
44pub struct HeaderBuilder {
45 header: HashMap<String, String>,
46}
47impl Default for HeaderBuilder {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52impl HeaderBuilder {
53 pub fn new() -> Self {
54 Self {
55 header: HashMap::default(),
56 }
57 }
58
59 pub fn add_common_header(mut self) -> Self {
60 self.header
61 .insert("expect".to_owned(), "100-continue".to_owned());
62 self
63 }
64
65 pub fn set_label(mut self, label: String) -> Self {
68 self.header.insert("label".to_owned(), label);
69 self
70 }
71
72 pub fn set_columns_name(mut self, columns_name: Vec<&str>) -> Self {
73 let columns_name_str = columns_name.join(",");
74 self.header.insert("columns".to_owned(), columns_name_str);
75 self
76 }
77
78 pub fn add_hidden_column(mut self) -> Self {
80 self.header
81 .insert("hidden_columns".to_owned(), DORIS_DELETE_SIGN.to_owned());
82 self
83 }
84
85 pub fn enable_2_pc(mut self) -> Self {
88 self.header
89 .insert("two_phase_commit".to_owned(), "true".to_owned());
90 self
91 }
92
93 pub fn set_user_password(mut self, user: String, password: String) -> Self {
94 let auth = format!(
95 "Basic {}",
96 general_purpose::STANDARD.encode(format!("{}:{}", user, password))
97 );
98 self.header.insert("Authorization".to_owned(), auth);
99 self
100 }
101
102 pub fn set_txn_id(mut self, txn_id: i64) -> Self {
105 self.header
106 .insert("txn_operation".to_owned(), txn_id.to_string());
107 self
108 }
109
110 pub fn add_commit(mut self) -> Self {
113 self.header
114 .insert("txn_operation".to_owned(), "commit".to_owned());
115 self
116 }
117
118 pub fn add_abort(mut self) -> Self {
121 self.header
122 .insert("txn_operation".to_owned(), "abort".to_owned());
123 self
124 }
125
126 pub fn add_json_format(mut self) -> Self {
127 self.header.insert("format".to_owned(), "json".to_owned());
128 self
129 }
130
131 pub fn add_read_json_by_line(mut self) -> Self {
133 self.header
134 .insert("read_json_by_line".to_owned(), "true".to_owned());
135 self
136 }
137
138 pub fn add_strip_outer_array(mut self) -> Self {
140 self.header
141 .insert("strip_outer_array".to_owned(), "true".to_owned());
142 self
143 }
144
145 pub fn set_partial_update(mut self, partial_update: Option<String>) -> Self {
147 self.header.insert(
148 "partial_update".to_owned(),
149 partial_update.unwrap_or_else(|| "false".to_owned()),
150 );
151 self
152 }
153
154 pub fn set_partial_columns(mut self, partial_columns: Option<String>) -> Self {
156 self.header.insert(
157 "partial_columns".to_owned(),
158 partial_columns.unwrap_or_else(|| "false".to_owned()),
159 );
160 self
161 }
162
163 pub fn set_db(mut self, db: String) -> Self {
165 self.header.insert("db".to_owned(), db);
166 self
167 }
168
169 pub fn set_table(mut self, table: String) -> Self {
171 self.header.insert("table".to_owned(), table);
172 self
173 }
174
175 pub fn build(self) -> HashMap<String, String> {
176 self.header
177 }
178}
179
180fn try_get_be_url(resp: &Response, fe_host: &str) -> Result<Option<Url>> {
187 match resp.status() {
188 StatusCode::TEMPORARY_REDIRECT => {
189 let be_url = resp
190 .headers()
191 .get("location")
192 .ok_or_else(|| {
193 SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url in header",))
194 })?
195 .to_str()
196 .context("Can't get doris BE url in header")
197 .map_err(SinkError::DorisStarrocksConnect)?
198 .to_owned();
199
200 let mut parsed_be_url = Url::parse(&be_url)
201 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
202
203 if fe_host != LOCALHOST && fe_host != LOCALHOST_IP {
204 let be_host = parsed_be_url.host_str().ok_or_else(|| {
205 SinkError::DorisStarrocksConnect(anyhow!("Can't get be host from url"))
206 })?;
207
208 if be_host == LOCALHOST || be_host == LOCALHOST_IP {
209 parsed_be_url
212 .set_host(Some(fe_host))
213 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
214 }
215 }
216 Ok(Some(parsed_be_url))
217 }
218 StatusCode::OK => {
219 Ok(None)
223 }
224 _ => Err(SinkError::DorisStarrocksConnect(anyhow!(
225 "Can't get doris BE url",
226 ))),
227 }
228}
229
230pub struct InserterInnerBuilder {
231 url: String,
232 header: HashMap<String, String>,
233 #[expect(dead_code)]
234 sender: Option<Sender>,
235 fe_host: String,
236}
237impl InserterInnerBuilder {
238 pub fn new(
239 url: String,
240 db: String,
241 table: String,
242 header: HashMap<String, String>,
243 ) -> Result<Self> {
244 let fe_host = Url::parse(&url)
245 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
246 .host_str()
247 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
248 .to_owned();
249 let url = format!("{}/api/{}/{}/_stream_load", url, db, table);
250
251 Ok(Self {
252 url,
253 sender: None,
254 header,
255 fe_host,
256 })
257 }
258
259 fn build_request(&self, uri: String) -> Result<RequestBuilder> {
260 let client = Client::builder()
261 .pool_idle_timeout(POOL_IDLE_TIMEOUT)
262 .redirect(redirect::Policy::none()) .build()
264 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
265
266 let mut builder = client.put(uri);
267 for (k, v) in &self.header {
268 builder = builder.header(k, v);
269 }
270 Ok(builder)
271 }
272
273 pub async fn build(&self) -> Result<InserterInner> {
274 let builder = self.build_request(self.url.clone())?;
275 let resp = builder
276 .send()
277 .await
278 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
279
280 let be_url = try_get_be_url(&resp, self.fe_host.as_str())?
281 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get doris BE url",)))?;
282
283 let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
284 let body = Body::wrap_stream(
285 tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
286 );
287 let builder = self.build_request(be_url.into())?.body(body);
288
289 let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
290 let response = builder
291 .send()
292 .await
293 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
294 let status = response.status();
295 let raw = response
296 .bytes()
297 .await
298 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
299 .into();
300
301 if status == StatusCode::OK {
302 Ok(raw)
303 } else {
304 let response_body = String::from_utf8(raw).map_err(|err| {
305 SinkError::DorisStarrocksConnect(
306 anyhow!(err).context("failed to parse response body"),
307 )
308 })?;
309 Err(SinkError::DorisStarrocksConnect(anyhow!(
310 "Failed connection {:?},{:?}",
311 status,
312 response_body
313 )))
314 }
315 });
316 Ok(InserterInner::new(sender, handle, WAIT_HANDDLE_TIMEOUT))
317 }
318}
319
320type Sender = UnboundedSender<Bytes>;
321
322pub struct InserterInner {
323 sender: Option<Sender>,
324 join_handle: JoinHandle<Result<Vec<u8>>>,
325 buffer: BytesMut,
326 stream_load_http_timeout: Duration,
327}
328impl InserterInner {
329 pub fn new(
330 sender: Sender,
331 join_handle: JoinHandle<Result<Vec<u8>>>,
332 stream_load_http_timeout: Duration,
333 ) -> Self {
334 Self {
335 sender: Some(sender),
336 join_handle,
337 buffer: BytesMut::with_capacity(BUFFER_SIZE),
338 stream_load_http_timeout,
339 }
340 }
341
342 async fn send_chunk(&mut self) -> Result<()> {
343 if self.sender.is_none() {
344 return Ok(());
345 }
346
347 let chunk = mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE));
348
349 match self.sender.as_mut().unwrap().send(chunk.freeze()) {
350 Err(_e) => {
351 self.sender.take();
352 self.wait_handle().await?;
353
354 Err(SinkError::DorisStarrocksConnect(anyhow!("channel closed")))
355 }
356 _ => Ok(()),
357 }
358 }
359
360 pub async fn write(&mut self, data: Bytes) -> Result<()> {
361 self.buffer.put_slice(&data);
362 if self.buffer.len() >= MIN_CHUNK_SIZE {
363 self.send_chunk().await?;
364 }
365 Ok(())
366 }
367
368 async fn wait_handle(&mut self) -> Result<Vec<u8>> {
369 let res = match tokio::time::timeout(self.stream_load_http_timeout, &mut self.join_handle)
370 .await
371 {
372 Ok(res) => res.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))??,
373 Err(err) => return Err(SinkError::DorisStarrocksConnect(anyhow!(err))),
374 };
375 Ok(res)
376 }
377
378 pub async fn finish(mut self) -> Result<Vec<u8>> {
379 if !self.buffer.is_empty() {
380 self.send_chunk().await?;
381 }
382 self.sender = None;
383 self.wait_handle().await
384 }
385}
386
/// Outcome of `send_stream_load_request`.
enum StreamLoadResponse {
    /// A request that has NOT been sent yet and whose URL was replaced by the redirect target;
    /// produced when the redirect target's port differs from the port of the request that was
    /// redirected.
    BeRequest(Request),
    /// The response received when the server answered `200 OK` instead of redirecting.
    HttpResponse(Response),
}
391
392async fn send_stream_load_request(
399 client: Client,
400 mut request: Request,
401 fe_host: &str,
402) -> Result<StreamLoadResponse> {
403 for _ in 0..2 {
408 let original_http_port = request.url().port();
409 let mut request_for_redirection = request
410 .try_clone()
411 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't clone request")))?;
412 let resp = client.execute(request).await.map_err(|err| {
413 SinkError::DorisStarrocksConnect(
414 anyhow!(err).context("sending stream load request failed"),
415 )
416 })?;
417 let be_url = try_get_be_url(&resp, fe_host)?;
418 match be_url {
419 Some(be_url) => {
420 let redirected_port = be_url.port();
428 *request_for_redirection.url_mut() = be_url;
429 if redirected_port == original_http_port {
430 request = request_for_redirection;
432 } else {
433 return Ok(StreamLoadResponse::BeRequest(request_for_redirection));
435 }
436 }
437 None => return Ok(StreamLoadResponse::HttpResponse(resp)),
438 }
439 }
440 Err(SinkError::DorisStarrocksConnect(anyhow!(
441 "redirection occur more than twice when sending stream load request"
442 )))
443}
444
445pub struct MetaRequestSender {
446 client: Client,
447 request: Request,
448 fe_host: String,
449}
450
451impl MetaRequestSender {
452 pub fn new(client: Client, request: Request, fe_host: String) -> Self {
453 Self {
454 client,
455 request,
456 fe_host,
457 }
458 }
459
460 pub async fn send(self) -> Result<Bytes> {
461 match send_stream_load_request(self.client.clone(), self.request, &self.fe_host).await? {
462 StreamLoadResponse::BeRequest(be_request) => self
463 .client
464 .execute(be_request)
465 .await
466 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
467 .bytes()
468 .await
469 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
470 StreamLoadResponse::HttpResponse(resp) => resp
471 .bytes()
472 .await
473 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err))),
474 }
475 }
476}
477
478pub struct StarrocksTxnRequestBuilder {
479 url_begin: String,
480 url_load: String,
481 url_prepare: String,
482 url_commit: String,
483 url_rollback: String,
484 header: HashMap<String, String>,
485 fe_host: String,
486 stream_load_http_timeout: Duration,
487 client: Client,
490}
491
492impl StarrocksTxnRequestBuilder {
493 pub fn new(
494 url: String,
495 header: HashMap<String, String>,
496 stream_load_http_timeout_ms: u64,
497 ) -> Result<Self> {
498 let fe_host = Url::parse(&url)
499 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
500 .host_str()
501 .ok_or_else(|| SinkError::DorisStarrocksConnect(anyhow!("Can't get fe host from url")))?
502 .to_owned();
503
504 let url_begin = format!("{}/api/transaction/begin", url);
505 let url_load = format!("{}/api/transaction/load", url);
506 let url_prepare = format!("{}/api/transaction/prepare", url);
507 let url_commit = format!("{}/api/transaction/commit", url);
508 let url_rollback = format!("{}/api/transaction/rollback", url);
509
510 let stream_load_http_timeout = Duration::from_millis(stream_load_http_timeout_ms);
511
512 let client = Client::builder()
513 .pool_idle_timeout(POOL_IDLE_TIMEOUT)
514 .redirect(redirect::Policy::none())
515 .build()
516 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
517
518 Ok(Self {
519 url_begin,
520 url_load,
521 url_prepare,
522 url_commit,
523 url_rollback,
524 header,
525 fe_host,
526 stream_load_http_timeout,
527 client,
528 })
529 }
530
531 fn build_request(&self, uri: String, method: Method, label: String) -> Result<Request> {
532 let parsed_url =
533 Url::parse(&uri).map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
534 let mut request = Request::new(method, parsed_url);
535
536 if uri != self.url_load {
537 *request.timeout_mut() = Some(self.stream_load_http_timeout);
539 }
540
541 let header = request.headers_mut();
542 for (k, v) in &self.header {
543 header.insert(
544 HeaderName::try_from(k)
545 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
546 HeaderValue::try_from(v)
547 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
548 );
549 }
550 header.insert(
551 "label",
552 HeaderValue::try_from(label)
553 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
554 );
555
556 Ok(request)
557 }
558
559 pub fn build_begin_request_sender(&self, label: String) -> Result<MetaRequestSender> {
560 let request = self.build_request(self.url_begin.clone(), Method::POST, label)?;
561 Ok(MetaRequestSender::new(
562 self.client.clone(),
563 request,
564 self.fe_host.clone(),
565 ))
566 }
567
568 pub fn build_prepare_request_sender(&self, label: String) -> Result<MetaRequestSender> {
569 let request = self.build_request(self.url_prepare.clone(), Method::POST, label)?;
570 Ok(MetaRequestSender::new(
571 self.client.clone(),
572 request,
573 self.fe_host.clone(),
574 ))
575 }
576
577 pub fn build_commit_request_sender(&self, label: String) -> Result<MetaRequestSender> {
578 let request = self.build_request(self.url_commit.clone(), Method::POST, label)?;
579 Ok(MetaRequestSender::new(
580 self.client.clone(),
581 request,
582 self.fe_host.clone(),
583 ))
584 }
585
586 pub fn build_rollback_request_sender(&self, label: String) -> Result<MetaRequestSender> {
587 let request = self.build_request(self.url_rollback.clone(), Method::POST, label)?;
588 Ok(MetaRequestSender::new(
589 self.client.clone(),
590 request,
591 self.fe_host.clone(),
592 ))
593 }
594
595 pub async fn build_txn_inserter(&self, label: String) -> Result<InserterInner> {
596 let request = self.build_request(self.url_load.clone(), Method::PUT, label.clone())?;
597 let mut be_request =
598 match send_stream_load_request(self.client.clone(), request, &self.fe_host).await? {
599 StreamLoadResponse::BeRequest(be_request) => be_request,
600 StreamLoadResponse::HttpResponse(resp) => {
601 let url = resp.url().clone();
604 self.build_request(url.into(), Method::PUT, label)?
605 }
606 };
607 let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
608 let body = Body::wrap_stream(
609 tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
610 );
611 *be_request.body_mut() = Some(body);
612
613 let client = self.client.clone();
614 let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
615 let response = client
616 .execute(be_request)
617 .await
618 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
619
620 let status = response.status();
621 let raw = response
622 .bytes()
623 .await
624 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
625 .into();
626
627 if status == StatusCode::OK {
628 Ok(raw)
629 } else {
630 let response_body = String::from_utf8(raw).map_err(|err| {
631 SinkError::DorisStarrocksConnect(
632 anyhow!(err).context("failed to parse response body"),
633 )
634 })?;
635 Err(SinkError::DorisStarrocksConnect(anyhow!(
636 "Failed connection {:?},{:?}",
637 status,
638 response_body
639 )))
640 }
641 });
642 Ok(InserterInner::new(
643 sender,
644 handle,
645 self.stream_load_http_timeout,
646 ))
647 }
648}