use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use mysql_async::Opts;
use mysql_async::prelude::Queryable;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use url::form_urlencoded;
use with_options::WithOptions;

use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
use super::doris_starrocks_connector::{
    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
    StarrocksTxnRequestBuilder,
};
use super::encoder::{JsonEncoder, RowEncoder};
use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
    SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
use crate::sink::writer::SinkWriter;
use crate::sink::{Result, Sink, SinkWriterParam};

pub const STARROCKS_SINK: &str = "starrocks";
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;

pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    30 * 1000
}

const fn default_use_https() -> bool {
    false
}

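/// Connection properties shared by the StarRocks sink: frontend host, MySQL/HTTP ports,
/// credentials, and the target database and table.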
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    #[serde(rename = "starrocks.host")]
    pub host: String,
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    #[serde(rename = "starrocks.user")]
    pub user: String,
    #[serde(rename = "starrocks.password")]
    pub password: String,
    #[serde(rename = "starrocks.database")]
    pub database: String,
    #[serde(rename = "starrocks.table")]
    pub table: String,

    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}

impl EnforceSecret for StarrocksCommon {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}

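/// Full configuration of the StarRocks sink: the connection properties plus the
/// stream load HTTP timeout, checkpoint commit interval, partial-update flag, and sink type.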
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    pub r#type: String, // accepts "append-only" or "upsert"
}

impl EnforceSecret for StarrocksConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}

fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}

impl StarrocksConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be either {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        if config.commit_checkpoint_interval == 0 {
            return Err(SinkError::Config(anyhow!(
                "`commit_checkpoint_interval` must be greater than 0"
            )));
        }
        Ok(config)
    }
}

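/// Sink that writes RisingWave stream chunks into a StarRocks table via the
/// stream load transaction HTTP API.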
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for StarrocksSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            StarrocksConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl StarrocksSink {
    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
        let pk_indices = param.downstream_pk_or_empty();
        let is_append_only = param.sink_type.is_append_only();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl StarrocksSink {
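    /// Verify that every sink column exists in the StarRocks table and that its type is
    /// compatible with the corresponding StarRocks column type.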
    fn check_column_name_and_type(
        &self,
        starrocks_columns_desc: HashMap<String, String>,
    ) -> Result<()> {
        let rw_fields_name = self.schema.fields();
        if rw_fields_name.len() > starrocks_columns_desc.len() {
            return Err(SinkError::Starrocks(
                "The columns of the sink must be a subset of the columns of the target StarRocks table.".to_owned(),
            ));
        }

        for i in rw_fields_name {
            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Column {:?} not found in the StarRocks table",
                    i.name
                ))
            })?;
            if !Self::check_and_correct_column_type(&i.data_type, value)? {
                return Err(SinkError::Starrocks(format!(
                    "Column type mismatch for {:?}: StarRocks type is {:?}, RisingWave type is {:?}",
                    i.name, value, i.data_type
                )));
            }
        }
        Ok(())
    }

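    /// Return whether `rw_data_type` is compatible with the given StarRocks column type string;
    /// types that cannot be sunk to StarRocks produce an error instead.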
    fn check_and_correct_column_type(
        rw_data_type: &DataType,
        starrocks_data_type: &String,
    ) -> Result<bool> {
        match rw_data_type {
            risingwave_common::types::DataType::Boolean => {
                Ok(starrocks_data_type.contains("tinyint")
                    || starrocks_data_type.contains("boolean"))
            }
            risingwave_common::types::DataType::Int16 => {
                Ok(starrocks_data_type.contains("smallint"))
            }
            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
            risingwave_common::types::DataType::Float32 => {
                Ok(starrocks_data_type.contains("float"))
            }
            risingwave_common::types::DataType::Float64 => {
                Ok(starrocks_data_type.contains("double"))
            }
            risingwave_common::types::DataType::Decimal => {
                Ok(starrocks_data_type.contains("decimal"))
            }
            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
            risingwave_common::types::DataType::Varchar => {
                Ok(starrocks_data_type.contains("varchar"))
            }
            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            risingwave_common::types::DataType::Timestamp => {
                Ok(starrocks_data_type.contains("datetime"))
            }
            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
            )),
            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
                "STRUCT is not supported for Starrocks sink.".to_owned(),
            )),
            risingwave_common::types::DataType::List(list) => {
                if starrocks_data_type.contains("unknown") {
                    return Ok(true);
                }
                let check_result =
                    Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
                Ok(check_result && starrocks_data_type.contains("array"))
            }
            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
            risingwave_common::types::DataType::Serial => {
                Ok(starrocks_data_type.contains("bigint"))
            }
            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
                "INT256 is not supported for Starrocks sink.".to_owned(),
            )),
            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
                "MAP is not supported for Starrocks sink.".to_owned(),
            )),
            DataType::Vector(_) => Err(SinkError::Starrocks(
                "VECTOR is not supported for Starrocks sink.".to_owned(),
            )),
        }
    }
}

impl Sink for StarrocksSink {
    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;

    const SINK_NAME: &'static str = STARROCKS_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert StarRocks sink (please define it in the `primary_key` field)"
            )));
        }
        let mut client = StarrocksSchemaClient::new(
            self.config.common.host.clone(),
            self.config.common.mysql_port.clone(),
            self.config.common.table.clone(),
            self.config.common.database.clone(),
            self.config.common.user.clone(),
            self.config.common.password.clone(),
        )
        .await?;
        let (read_model, pks) = client.get_pk_from_starrocks().await?;

        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
            return Err(SinkError::Config(anyhow!(
                "If you want to use upsert, please create the StarRocks table with primary keys (table model PRIMARY_KEYS)"
            )));
        }

        for (index, field) in self.schema.fields().iter().enumerate() {
            if self.pk_indices.contains(&index) && !pks.contains(&field.name) {
                return Err(SinkError::Starrocks(format!(
                    "Can't find primary key {:?} in StarRocks",
                    field.name
                )));
            }
        }

        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;

        self.check_column_name_and_type(starrocks_columns_desc)?;
        Ok(())
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        StarrocksConfig::from_btreemap(config.clone())?;
        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        let writer = StarrocksSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
            writer_param.executor_id,
        )?;

        let metrics = SinkWriterMetrics::new(&writer_param);

        Ok(DecoupleCheckpointLogSinkerOf::new(
            writer,
            metrics,
            commit_checkpoint_interval,
        ))
    }
}

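/// Writer that encodes rows as JSON and streams them to StarRocks inside a transaction
/// that is committed on checkpoint barriers.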
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: Option<StarrocksClient>,
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    executor_id: u64,
    curr_txn_label: Option<String>,
}

impl TryFrom<SinkParam> for StarrocksSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
        StarrocksSink::new(param, config, schema)
    }
}

impl StarrocksSinkWriter {
    pub fn new(
        config: StarrocksConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
        executor_id: u64,
    ) -> Result<Self> {
        let mut field_names = schema.names_str();
        if !is_append_only {
            field_names.push(STARROCKS_DELETE_SIGN);
        };
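        // Quote every column name in MySQL-style backticks so that reserved words are
        // still accepted by StarRocks.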
401 let field_names = field_names
404 .into_iter()
405 .map(|name| format!("`{}`", name))
406 .collect::<Vec<String>>();
407 let field_names_str = field_names
408 .iter()
409 .map(|name| name.as_str())
410 .collect::<Vec<&str>>();
411
412 let header = HeaderBuilder::new()
413 .add_common_header()
414 .set_user_password(config.common.user.clone(), config.common.password.clone())
415 .add_json_format()
416 .set_partial_update(config.partial_update.clone())
417 .set_columns_name(field_names_str)
418 .set_db(config.common.database.clone())
419 .set_table(config.common.table.clone())
420 .build();
421
422 let url = if config.common.use_https {
423 format!("https://{}:{}", config.common.host, config.common.http_port)
424 } else {
425 format!("http://{}:{}", config.common.host, config.common.http_port)
426 };
427 let txn_request_builder =
428 StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
429
430 Ok(Self {
431 config,
432 schema: schema.clone(),
433 pk_indices,
434 is_append_only,
435 client: None,
436 txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
437 row_encoder: JsonEncoder::new_with_starrocks(schema, None),
438 executor_id,
439 curr_txn_label: None,
440 })
441 }
442
    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                continue;
            }
            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
            self.client
                .as_mut()
                .ok_or_else(|| {
                    SinkError::Starrocks("StarRocks sink client is not initialized".to_owned())
                })?
                .write(row_json_string.into())
                .await?;
        }
        Ok(())
    }

    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialization error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks sink client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::Delete => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("1".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialization error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks sink client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                // The following UpdateInsert overwrites the row by primary key, so there is
                // nothing to do for UpdateDelete.
                Op::UpdateDelete => {}
                Op::UpdateInsert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialization error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks sink client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
            }
        }
        Ok(())
    }

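    /// Build a transaction label from the executor id and the current timestamp in
    /// microseconds, so that concurrent writers generate distinct labels.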
    #[inline(always)]
    fn new_txn_label(&self) -> String {
        format!(
            "rw-txn-{}-{}",
            self.executor_id,
            chrono::Utc::now().timestamp_micros()
        )
    }

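    /// Prepare and then commit the transaction identified by `txn_label`, checking that
    /// StarRocks echoes the same label back for both requests.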
    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
        tracing::debug!(?txn_label, "prepare transaction");
        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "expected label {} but prepare transaction returned label {}",
                txn_label, txn_label_res
            )));
        }
        tracing::debug!(?txn_label, "commit transaction");
        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "expected label {} but commit transaction returned label {}",
                txn_label, txn_label_res
            )));
        }
        Ok(())
    }
}

impl Drop for StarrocksSinkWriter {
    fn drop(&mut self) {
        if let Some(txn_label) = self.curr_txn_label.take() {
            let txn_client = self.txn_client.clone();
            tokio::spawn(async move {
                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
                    tracing::error!(
                        "starrocks rollback transaction error: {:?}, txn label: {}",
                        e.as_report(),
                        txn_label
                    );
                }
            });
        }
    }
}

#[async_trait]
impl SinkWriter for StarrocksSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        if self.curr_txn_label.is_none() {
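            // No transaction is open yet: begin one and keep its label until the next
            // checkpoint barrier commits it.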
            let txn_label = self.new_txn_label();
            tracing::debug!(?txn_label, "begin transaction");
            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
            if txn_label != txn_label_res {
                return Err(SinkError::Starrocks(format!(
                    "expected label {} but begin transaction returned label {}",
                    txn_label, txn_label_res
                )));
            }
            self.curr_txn_label = Some(txn_label.clone());
        }
        if self.client.is_none() {
            let txn_label = self.curr_txn_label.clone();
            self.client = Some(StarrocksClient::new(
                self.txn_client.load(txn_label.unwrap()).await?,
            ));
        }
        if self.is_append_only {
            self.append_only(chunk).await
        } else {
            self.upsert(chunk).await
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        if let Some(client) = self.client.take() {
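            // Finish the in-flight stream load request; the data only becomes visible
            // once the transaction is committed below on a checkpoint barrier.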
            client.finish().await?;
        }

        if is_checkpoint
            && let Some(txn_label) = self.curr_txn_label.take()
            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
        {
            match self.txn_client.rollback(txn_label.clone()).await {
                Ok(_) => tracing::warn!(
                    ?txn_label,
                    "transaction is successfully rolled back due to commit failure"
                ),
                Err(err) => {
                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
                }
            }

            return Err(err);
        }
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        if let Some(txn_label) = self.curr_txn_label.take() {
            tracing::debug!(?txn_label, "rollback transaction");
            self.txn_client.rollback(txn_label).await?;
        }
        Ok(())
    }
}

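/// A small MySQL client used during validation to read table metadata
/// (columns, table model, and keys) from the StarRocks frontend.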
pub struct StarrocksSchemaClient {
    table: String,
    db: String,
    conn: mysql_async::Conn,
}

impl StarrocksSchemaClient {
    pub async fn new(
        host: String,
        port: String,
        table: String,
        db: String,
        user: String,
        password: String,
    ) -> Result<Self> {
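        // URL-encode the credentials so that special characters don't break the
        // MySQL connection URL below.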
        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();

        let conn_uri = format!(
            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
            user,
            password,
            host,
            port,
            db,
            STARROCK_MYSQL_PREFER_SOCKET,
            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
            STARROCK_MYSQL_WAIT_TIMEOUT
        );
        let pool = mysql_async::Pool::new(
            Opts::from_url(&conn_uri)
                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
        );
        let conn = pool
            .get_conn()
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        Ok(Self { table, db, conn })
    }

    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
        let query = format!(
            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let mut query_map: HashMap<String, String> = HashMap::default();
        self.conn
            .query_map(query, |(column_name, column_type)| {
                query_map.insert(column_name, column_type)
            })
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        Ok(query_map)
    }

    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
        let query = format!(
            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let table_mode_pk: (String, String) = self
            .conn
            .query_map(
                query,
                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
                    .as_str()
                {
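                    // Aggregate-key tables expose no primary key here, so fall back
                    // to the sort key.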
                    "AGG_KEYS" => (table_model, sort_key),
                    _ => (table_model, primary_key),
                },
            )
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
            .first()
            .ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Can't find schema with table {:?} and database {:?}",
                    self.table, self.db
                ))
            })?
            .clone();
        Ok(table_mode_pk)
    }
}

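/// Response payload returned by the StarRocks stream load and transaction HTTP APIs.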
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}

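/// Thin wrapper around an in-flight stream load request created via `StarrocksTxnClient::load`.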
pub struct StarrocksClient {
    insert: InserterInner,
}

impl StarrocksClient {
    pub fn new(insert: InserterInner) -> Self {
        Self { insert }
    }

    pub async fn write(&mut self, data: Bytes) -> Result<()> {
        self.insert.write(data).await?;
        Ok(())
    }

    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
        let raw = self.insert.finish().await?;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "Insert error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        };
        Ok(res)
    }
}

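/// Client for the StarRocks transaction stream load HTTP API: begin, prepare, commit,
/// rollback, and loading data into an open transaction.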
pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}

impl StarrocksTxnClient {
    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
        Self { request_builder }
    }

    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "transaction error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        }
        res.label.ok_or_else(|| {
            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
        })
    }

    pub async fn begin(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_begin_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn prepare(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_prepare_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn commit(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_commit_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn rollback(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_rollback_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn load(&self, label: String) -> Result<InserterInner> {
        self.request_builder.build_txn_inserter(label).await
    }
}