use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use mysql_async::Opts;
use mysql_async::prelude::Queryable;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_pb::id::ExecutorId;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use url::form_urlencoded;
use with_options::WithOptions;

use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
use super::doris_starrocks_connector::{
    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
    StarrocksTxnRequestBuilder,
};
use super::encoder::{JsonEncoder, RowEncoder};
use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
    SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
use crate::sink::writer::SinkWriter;
use crate::sink::{Result, Sink, SinkWriterParam};

pub const STARROCKS_SINK: &str = "starrocks";
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;

pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    30 * 1000
}

const fn default_use_https() -> bool {
    false
}

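/// Connection options shared by the StarRocks MySQL (frontend query) and HTTP (stream load)
/// interfaces.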
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The StarRocks host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port of the MySQL server of the StarRocks frontend.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port of the HTTP server of the StarRocks frontend.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the StarRocks database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The StarRocks database where the target table is located.
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The StarRocks table to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to connect to the StarRocks HTTP server over HTTPS. Defaults to `false`.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}

impl EnforceSecret for StarrocksCommon {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}

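/// Configuration for the StarRocks sink, built from the sink's properties (`WITH` options).
/// The connection options in [`StarrocksCommon`] are flattened into it.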
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// Timeout in milliseconds for a stream load HTTP request. Defaults to 30000 (30 s).
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    /// Commit every n (> 0) checkpoints. Defaults to
    /// `DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE`.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable partial update.
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    /// Sink type: either `append-only` or `upsert`.
    pub r#type: String,
}

impl EnforceSecret for StarrocksConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}

fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}

impl StarrocksConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be `{}` or `{}`",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        if config.commit_checkpoint_interval == 0 {
            return Err(SinkError::Config(anyhow!(
                "`commit_checkpoint_interval` must be greater than 0"
            )));
        }
        Ok(config)
    }
}

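/// A sink that writes RisingWave stream chunks into a StarRocks table. Supports append-only
/// and upsert modes; upsert requires a PRIMARY KEY table on the StarRocks side.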
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for StarrocksSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            StarrocksConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl StarrocksSink {
    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
        let pk_indices = param.downstream_pk_or_empty();
        let is_append_only = param.sink_type.is_append_only();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl StarrocksSink {
    fn check_column_name_and_type(
        &self,
        starrocks_columns_desc: HashMap<String, String>,
    ) -> Result<()> {
        let rw_fields_name = self.schema.fields();
        if rw_fields_name.len() > starrocks_columns_desc.len() {
            return Err(SinkError::Starrocks(
                "The StarRocks table must contain all columns of the sink, i.e. the sink's columns must be a subset of the target table's columns.".to_owned(),
            ));
        }

        for i in rw_fields_name {
            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Column {:?} was not found in the StarRocks table",
                    i.name
                ))
            })?;
            if !Self::check_and_correct_column_type(&i.data_type, value)? {
                return Err(SinkError::Starrocks(format!(
                    "Column type mismatch for column {:?}: StarRocks type is {:?}, RisingWave type is {:?}",
                    i.name, value, i.data_type
                )));
            }
        }
        Ok(())
    }

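    /// Checks whether a RisingWave data type is compatible with the StarRocks column type
    /// string reported by `information_schema.columns`.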
    fn check_and_correct_column_type(
        rw_data_type: &DataType,
        starrocks_data_type: &String,
    ) -> Result<bool> {
        match rw_data_type {
            DataType::Boolean => Ok(starrocks_data_type.contains("tinyint")
                || starrocks_data_type.contains("boolean")),
            DataType::Int16 => Ok(starrocks_data_type.contains("smallint")),
            DataType::Int32 => Ok(starrocks_data_type.contains("int")),
            DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
            DataType::Float32 => Ok(starrocks_data_type.contains("float")),
            DataType::Float64 => Ok(starrocks_data_type.contains("double")),
            DataType::Decimal => Ok(starrocks_data_type.contains("decimal")),
            DataType::Date => Ok(starrocks_data_type.contains("date")),
            DataType::Varchar => Ok(starrocks_data_type.contains("varchar")),
            DataType::Time => Err(SinkError::Starrocks(
                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            DataType::Timestamp => Ok(starrocks_data_type.contains("datetime")),
            DataType::Timestamptz => Err(SinkError::Starrocks(
                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
            )),
            DataType::Interval => Err(SinkError::Starrocks(
                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            DataType::Struct(_) => Err(SinkError::Starrocks(
                "STRUCT is not supported for Starrocks sink.".to_owned(),
            )),
            DataType::List(list) => {
                if starrocks_data_type.contains("unknown") {
                    return Ok(true);
                }
                let check_result =
                    Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
                Ok(check_result && starrocks_data_type.contains("array"))
            }
            DataType::Bytea => Err(SinkError::Starrocks(
                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
            DataType::Serial => Ok(starrocks_data_type.contains("bigint")),
            DataType::Int256 => Err(SinkError::Starrocks(
                "INT256 is not supported for Starrocks sink.".to_owned(),
            )),
            DataType::Map(_) => Err(SinkError::Starrocks(
                "MAP is not supported for Starrocks sink.".to_owned(),
            )),
            DataType::Vector(_) => Err(SinkError::Starrocks(
                "VECTOR is not supported for Starrocks sink.".to_owned(),
            )),
        }
    }
}

impl Sink for StarrocksSink {
    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;

    const SINK_NAME: &'static str = STARROCKS_SINK;

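    /// Checks that the sink is well formed: an upsert sink must have a primary key, the
    /// target table must use the PRIMARY KEY model for upserts, and the column names and
    /// types must match the StarRocks table.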
    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
            )));
        }
        let mut client = StarrocksSchemaClient::new(
            self.config.common.host.clone(),
            self.config.common.mysql_port.clone(),
            self.config.common.table.clone(),
            self.config.common.database.clone(),
            self.config.common.user.clone(),
            self.config.common.password.clone(),
        )
        .await?;
        let (read_model, pks) = client.get_pk_from_starrocks().await?;

        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
            return Err(SinkError::Config(anyhow!(
                "If you want to use upsert, please create the StarRocks table with the PRIMARY KEY model"
            )));
        }

        for (index, field) in self.schema.fields().iter().enumerate() {
            if self.pk_indices.contains(&index) && !pks.contains(&field.name) {
                return Err(SinkError::Starrocks(format!(
                    "Can't find primary key {:?} in the StarRocks table",
                    field.name
                )));
            }
        }

        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;

        self.check_column_name_and_type(starrocks_columns_desc)?;
        Ok(())
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        StarrocksConfig::from_btreemap(config.clone())?;
        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        let writer = StarrocksSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
            writer_param.executor_id,
        )?;

        let metrics = SinkWriterMetrics::new(&writer_param);

        Ok(DecoupleCheckpointLogSinkerOf::new(
            writer,
            metrics,
            commit_checkpoint_interval,
        ))
    }
}

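/// Sink writer that encodes stream chunks as JSON rows and pushes them to StarRocks through
/// the transactional stream load API. A transaction is opened lazily on the first write and
/// committed when a checkpoint barrier arrives.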
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: Option<StarrocksClient>,
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    executor_id: ExecutorId,
    curr_txn_label: Option<String>,
}

impl TryFrom<SinkParam> for StarrocksSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
        StarrocksSink::new(param, config, schema)
    }
}

impl StarrocksSinkWriter {
    pub fn new(
        config: StarrocksConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
        executor_id: ExecutorId,
    ) -> Result<Self> {
        let mut field_names = schema.names_str();
        // In upsert mode, append the StarRocks delete-sign column so that deletes can be
        // expressed in the stream load payload.
        if !is_append_only {
            field_names.push(STARROCKS_DELETE_SIGN);
        }
        // Backtick-quote the column names used in the stream load `columns` header.
        let field_names = field_names
            .into_iter()
            .map(|name| format!("`{}`", name))
            .collect::<Vec<String>>();
        let field_names_str = field_names
            .iter()
            .map(|name| name.as_str())
            .collect::<Vec<&str>>();

        let header = HeaderBuilder::new()
            .add_common_header()
            .set_user_password(config.common.user.clone(), config.common.password.clone())
            .add_json_format()
            .set_partial_update(config.partial_update.clone())
            .set_columns_name(field_names_str)
            .set_db(config.common.database.clone())
            .set_table(config.common.table.clone())
            .build();

        let url = if config.common.use_https {
            format!("https://{}:{}", config.common.host, config.common.http_port)
        } else {
            format!("http://{}:{}", config.common.host, config.common.http_port)
        };
        let txn_request_builder =
            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;

        Ok(Self {
            config,
            schema: schema.clone(),
            pk_indices,
            is_append_only,
            client: None,
            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
            executor_id,
            curr_txn_label: None,
        })
    }

    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                continue;
            }
            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
            self.client
                .as_mut()
                .ok_or_else(|| {
                    SinkError::Starrocks(
                        "StarRocks stream load client is not initialized".to_owned(),
                    )
                })?
                .write(row_json_string.into())
                .await?;
        }
        Ok(())
    }

    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks stream load client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::Delete => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("1".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks stream load client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::UpdateDelete => {}
                Op::UpdateInsert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks stream load client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
            }
        }
        Ok(())
    }

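    /// Generates a unique transaction label from the executor id and the current timestamp.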
    #[inline(always)]
    fn new_txn_label(&self) -> String {
        format!(
            "rw-txn-{}-{}",
            self.executor_id,
            chrono::Utc::now().timestamp_micros()
        )
    }

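    /// Prepares and then commits the current transaction through the StarRocks transactional
    /// stream load API, verifying that the label echoed back matches the one we sent.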
    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
        tracing::debug!(?txn_label, "prepare transaction");
        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "label {} returned from prepare transaction differs from the current label {}",
                txn_label_res, txn_label
            )));
        }
        tracing::debug!(?txn_label, "commit transaction");
        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "label {} returned from commit transaction differs from the current label {}",
                txn_label_res, txn_label
            )));
        }
        Ok(())
    }
}

impl Drop for StarrocksSinkWriter {
    fn drop(&mut self) {
        // Roll back any transaction that is still open when the writer is dropped, so it
        // doesn't linger on the StarRocks side until it times out.
        if let Some(txn_label) = self.curr_txn_label.take() {
            let txn_client = self.txn_client.clone();
            tokio::spawn(async move {
                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
                    tracing::error!(
                        "starrocks rollback transaction error: {:?}, txn label: {}",
                        e.as_report(),
                        txn_label
                    );
                }
            });
        }
    }
}

#[async_trait]
impl SinkWriter for StarrocksSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // The transaction is started lazily on the first write, so an epoch that carries no
        // data never opens a transaction.
        if self.curr_txn_label.is_none() {
            let txn_label = self.new_txn_label();
            tracing::debug!(?txn_label, "begin transaction");
            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
            if txn_label != txn_label_res {
                return Err(SinkError::Starrocks(format!(
                    "label {} returned from StarRocks differs from the generated label {}",
                    txn_label_res, txn_label
                )));
            }
            self.curr_txn_label = Some(txn_label.clone());
        }
        if self.client.is_none() {
            let txn_label = self.curr_txn_label.clone();
            self.client = Some(StarrocksClient::new(
                self.txn_client.load(txn_label.unwrap()).await?,
            ));
        }
        if self.is_append_only {
            self.append_only(chunk).await
        } else {
            self.upsert(chunk).await
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        if let Some(client) = self.client.take() {
            // Finish the in-flight stream load for this epoch. The data only becomes visible
            // in StarRocks after the transaction is committed on a checkpoint barrier below.
            client.finish().await?;
        }

        if is_checkpoint
            && let Some(txn_label) = self.curr_txn_label.take()
            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
        {
            match self.txn_client.rollback(txn_label.clone()).await {
                Ok(_) => tracing::warn!(
                    ?txn_label,
                    "transaction is successfully rolled back due to commit failure"
                ),
                Err(err) => {
                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
                }
            }

            return Err(err);
        }
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        if let Some(txn_label) = self.curr_txn_label.take() {
            tracing::debug!(?txn_label, "rollback transaction");
            self.txn_client.rollback(txn_label).await?;
        }
        Ok(())
    }
}

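/// A small MySQL client that talks to the StarRocks frontend during validation to read the
/// target table's column and key metadata from `information_schema`.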
pub struct StarrocksSchemaClient {
    table: String,
    db: String,
    conn: mysql_async::Conn,
}

impl StarrocksSchemaClient {
    pub async fn new(
        host: String,
        port: String,
        table: String,
        db: String,
        user: String,
        password: String,
    ) -> Result<Self> {
        // Percent-encode the user name and password so that special characters don't break
        // the MySQL connection URL.
        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();

        let conn_uri = format!(
            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
            user,
            password,
            host,
            port,
            db,
            STARROCK_MYSQL_PREFER_SOCKET,
            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
            STARROCK_MYSQL_WAIT_TIMEOUT
        );
        let pool = mysql_async::Pool::new(
            Opts::from_url(&conn_uri)
                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
        );
        let conn = pool
            .get_conn()
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        Ok(Self { table, db, conn })
    }

    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
        let query = format!(
            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let mut query_map: HashMap<String, String> = HashMap::default();
        self.conn
            .query_map(query, |(column_name, column_type)| {
                query_map.insert(column_name, column_type)
            })
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        Ok(query_map)
    }

    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
        let query = format!(
            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let table_mode_pk: (String, String) = self
            .conn
            .query_map(
                query,
                |(table_model, primary_key, sort_key): (String, String, String)| {
                    // Aggregate-key tables report their key columns via `sort_key`; all other
                    // table models use `primary_key`.
                    match table_model.as_str() {
                        "AGG_KEYS" => (table_model, sort_key),
                        _ => (table_model, primary_key),
                    }
                },
            )
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
            .first()
            .ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Can't find schema with table {:?} and database {:?}",
                    self.table, self.db
                ))
            })?
            .clone();
        Ok(table_mode_pk)
    }
}

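/// Response payload returned by the StarRocks stream load and transactional load HTTP
/// endpoints.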
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}

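/// Thin wrapper over an active stream load request: rows are appended with `write`, and the
/// load is flushed with `finish`, which parses and checks the StarRocks response.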
pub struct StarrocksClient {
    insert: InserterInner,
}

impl StarrocksClient {
    pub fn new(insert: InserterInner) -> Self {
        Self { insert }
    }

    pub async fn write(&mut self, data: Bytes) -> Result<()> {
        self.insert.write(data).await?;
        Ok(())
    }

    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
        let raw = self.insert.finish().await?;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "Insert error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        };
        Ok(res)
    }
}

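/// Client for the StarRocks transactional stream load API: begin, prepare, commit, or roll
/// back a transaction by label, and open an inserter bound to that label for loading data.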
pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}

impl StarrocksTxnClient {
    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
        Self { request_builder }
    }

    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "transaction error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        }
        res.label.ok_or_else(|| {
            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
        })
    }

    pub async fn begin(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_begin_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn prepare(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_prepare_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn commit(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_commit_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn rollback(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_rollback_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn load(&self, label: String) -> Result<InserterInner> {
        self.request_builder.build_txn_inserter(label).await
    }
}