1use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use mysql_async::Opts;
23use mysql_async::prelude::Queryable;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::types::DataType;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use serde_with::{DisplayFromStr, serde_as};
30use thiserror_ext::AsReport;
31use url::form_urlencoded;
32use with_options::WithOptions;
33
34use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
35use super::doris_starrocks_connector::{
36 HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
37 StarrocksTxnRequestBuilder,
38};
39use super::encoder::{JsonEncoder, RowEncoder};
40use super::{
41 SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
42 SinkWriterMetrics,
43};
44use crate::enforce_secret::EnforceSecret;
45use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
46use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};
47
/// Connector name of the StarRocks sink; also used as `Sink::SINK_NAME` of `StarrocksSink`.
pub const STARROCKS_SINK: &str = "starrocks";
/// Value of the `prefer_socket` query parameter of the MySQL connection URL built by
/// `StarrocksSchemaClient::new`.
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
/// Value of the `max_allowed_packet` query parameter of the same connection URL.
// NOTE(review): unit is presumably bytes — confirm against the `mysql_async` URL options.
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
/// Value of the `wait_timeout` query parameter of the same connection URL.
// NOTE(review): unit is presumably seconds — confirm against the `mysql_async` URL options.
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;
52
/// Serde default for `starrocks.stream_load.http.timeout.ms`: 30 seconds, expressed in
/// milliseconds.
pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    const TIMEOUT_SECONDS: u64 = 30;
    const MILLIS_PER_SECOND: u64 = 1_000;
    TIMEOUT_SECONDS * MILLIS_PER_SECOND
}
56
/// Serde default for `starrocks.use_https`: HTTPS is off unless the option is set.
const fn default_use_https() -> bool {
    const USE_HTTPS_BY_DEFAULT: bool = false;
    USE_HTTPS_BY_DEFAULT
}
60
/// Connection options shared by every part of the StarRocks sink.
///
/// Each field is deserialized from the option key given in its `serde(rename = ...)`
/// attribute; some keys also accept an alias.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// Host name used both for the MySQL-protocol schema connection and for the
    /// stream load HTTP(S) URL.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// Port of the MySQL-protocol endpoint used by `StarrocksSchemaClient`.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// Port of the HTTP(S) endpoint used for transaction stream load requests.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// User name for both the MySQL connection and the stream load requests.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// Password for both the MySQL connection and the stream load requests.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// Database that contains the target table.
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// Target table the sink writes into.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// When `true` the stream load URL uses the `https` scheme, otherwise `http`.
    /// Parsed from its string form; defaults to `default_use_https()`.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}
92
// Declares which option keys of `StarrocksCommon` are subject to secret enforcement.
impl EnforceSecret for StarrocksCommon {
    /// The credential options: password and user name.
    // NOTE(review): presumably these keys must be supplied via secrets rather than as
    // plain text — confirm against the `EnforceSecret` trait definition.
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}
98
/// Full option set of the StarRocks sink, built by `StarrocksConfig::from_btreemap`.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    /// Connection options; their keys are flattened into the same option map.
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// Timeout in milliseconds handed to `StarrocksTxnRequestBuilder::new`.
    /// Parsed from its string form; defaults to `_default_stream_load_http_timeout_ms()`.
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    /// Interval handed to `DecoupleCheckpointLogSinkerOf`; `from_btreemap` rejects `0`.
    /// Parsed from its string form; defaults to `default_commit_checkpoint_interval()`.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Optional value forwarded unchanged to `HeaderBuilder::set_partial_update`.
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    /// Sink type; `from_btreemap` only accepts `SINK_TYPE_APPEND_ONLY` or `SINK_TYPE_UPSERT`.
    pub r#type: String,
}
130
131impl EnforceSecret for StarrocksConfig {
132 fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
133 StarrocksCommon::enforce_one(prop)
134 }
135}
136
/// Serde default for `commit_checkpoint_interval`: the shared
/// `DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE` constant.
fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}
140
141impl StarrocksConfig {
142 pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
143 let config =
144 serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
145 .map_err(|e| SinkError::Config(anyhow!(e)))?;
146 if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
147 return Err(SinkError::Config(anyhow!(
148 "`{}` must be {}, or {}",
149 SINK_TYPE_OPTION,
150 SINK_TYPE_APPEND_ONLY,
151 SINK_TYPE_UPSERT
152 )));
153 }
154 if config.commit_checkpoint_interval == 0 {
155 return Err(SinkError::Config(anyhow!(
156 "`commit_checkpoint_interval` must be greater than 0"
157 )));
158 }
159 Ok(config)
160 }
161}
162
/// StarRocks sink: holds the parsed options together with the schema and key
/// information needed to validate the target table and to create writers.
#[derive(Debug)]
pub struct StarrocksSink {
    /// Parsed sink options.
    pub config: StarrocksConfig,
    /// Schema of the rows delivered to the sink.
    schema: Schema,
    /// Indices into `schema` of the downstream primary key columns.
    pk_indices: Vec<usize>,
    /// `true` for an append-only sink, `false` for an upsert sink.
    is_append_only: bool,
}
170
171impl EnforceSecret for StarrocksSink {
172 fn enforce_secret<'a>(
173 prop_iter: impl Iterator<Item = &'a str>,
174 ) -> crate::error::ConnectorResult<()> {
175 for prop in prop_iter {
176 StarrocksConfig::enforce_one(prop)?;
177 }
178 Ok(())
179 }
180}
181
182impl StarrocksSink {
183 pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
184 let pk_indices = param.downstream_pk.clone();
185 let is_append_only = param.sink_type.is_append_only();
186 Ok(Self {
187 config,
188 schema,
189 pk_indices,
190 is_append_only,
191 })
192 }
193}
194
195impl StarrocksSink {
196 fn check_column_name_and_type(
197 &self,
198 starrocks_columns_desc: HashMap<String, String>,
199 ) -> Result<()> {
200 let rw_fields_name = self.schema.fields();
201 if rw_fields_name.len() > starrocks_columns_desc.len() {
202 return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
203 }
204
205 for i in rw_fields_name {
206 let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
207 SinkError::Starrocks(format!(
208 "Column name don't find in starrocks, risingwave is {:?} ",
209 i.name
210 ))
211 })?;
212 if !Self::check_and_correct_column_type(&i.data_type, value)? {
213 return Err(SinkError::Starrocks(format!(
214 "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
215 i.name, value, i.data_type
216 )));
217 }
218 }
219 Ok(())
220 }
221
222 fn check_and_correct_column_type(
223 rw_data_type: &DataType,
224 starrocks_data_type: &String,
225 ) -> Result<bool> {
226 match rw_data_type {
227 risingwave_common::types::DataType::Boolean => {
228 Ok(starrocks_data_type.contains("tinyint") | starrocks_data_type.contains("boolean"))
229 }
230 risingwave_common::types::DataType::Int16 => {
231 Ok(starrocks_data_type.contains("smallint"))
232 }
233 risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
234 risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
235 risingwave_common::types::DataType::Float32 => {
236 Ok(starrocks_data_type.contains("float"))
237 }
238 risingwave_common::types::DataType::Float64 => {
239 Ok(starrocks_data_type.contains("double"))
240 }
241 risingwave_common::types::DataType::Decimal => {
242 Ok(starrocks_data_type.contains("decimal"))
243 }
244 risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
245 risingwave_common::types::DataType::Varchar => {
246 Ok(starrocks_data_type.contains("varchar"))
247 }
248 risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
249 "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
250 )),
251 risingwave_common::types::DataType::Timestamp => {
252 Ok(starrocks_data_type.contains("datetime"))
253 }
254 risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
255 "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
256 )),
257 risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
258 "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
259 )),
260 risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
261 "STRUCT is not supported for Starrocks sink.".to_owned(),
262 )),
263 risingwave_common::types::DataType::List(list) => {
264 if starrocks_data_type.contains("unknown") {
266 return Ok(true);
267 }
268 let check_result = Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
269 Ok(check_result && starrocks_data_type.contains("array"))
270 }
271 risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
272 "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
273 )),
274 risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
275 risingwave_common::types::DataType::Serial => {
276 Ok(starrocks_data_type.contains("bigint"))
277 }
278 risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
279 "INT256 is not supported for Starrocks sink.".to_owned(),
280 )),
281 risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
282 "MAP is not supported for Starrocks sink.".to_owned(),
283 )),
284 DataType::Vector(_) => Err(SinkError::Starrocks(
285 "VECTOR is not supported for Starrocks sink.".to_owned(),
286 )),
287 }
288 }
289}
290
291impl Sink for StarrocksSink {
292 type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
293
294 const SINK_NAME: &'static str = STARROCKS_SINK;
295
296 async fn validate(&self) -> Result<()> {
297 if !self.is_append_only && self.pk_indices.is_empty() {
298 return Err(SinkError::Config(anyhow!(
299 "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
300 )));
301 }
302 let mut client = StarrocksSchemaClient::new(
304 self.config.common.host.clone(),
305 self.config.common.mysql_port.clone(),
306 self.config.common.table.clone(),
307 self.config.common.database.clone(),
308 self.config.common.user.clone(),
309 self.config.common.password.clone(),
310 )
311 .await?;
312 let (read_model, pks) = client.get_pk_from_starrocks().await?;
313
314 if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
315 return Err(SinkError::Config(anyhow!(
316 "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
317 )));
318 }
319
320 for (index, filed) in self.schema.fields().iter().enumerate() {
321 if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
322 return Err(SinkError::Starrocks(format!(
323 "Can't find pk {:?} in starrocks",
324 filed.name
325 )));
326 }
327 }
328
329 let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
330
331 self.check_column_name_and_type(starrocks_columns_desc)?;
332 Ok(())
333 }
334
335 fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
336 StarrocksConfig::from_btreemap(config.clone())?;
337 Ok(())
338 }
339
340 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
341 let commit_checkpoint_interval =
342 NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
343 "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
344 );
345
346 let writer = StarrocksSinkWriter::new(
347 self.config.clone(),
348 self.schema.clone(),
349 self.pk_indices.clone(),
350 self.is_append_only,
351 writer_param.executor_id,
352 )?;
353
354 let metrics = SinkWriterMetrics::new(&writer_param);
355
356 Ok(DecoupleCheckpointLogSinkerOf::new(
357 writer,
358 metrics,
359 commit_checkpoint_interval,
360 ))
361 }
362}
363
/// Writer that streams rows into StarRocks through labelled load transactions.
pub struct StarrocksSinkWriter {
    /// Parsed sink options.
    pub config: StarrocksConfig,
    /// Row schema; stored but not read by the writer (hence `expect(dead_code)`).
    #[expect(dead_code)]
    schema: Schema,
    /// Primary key indices; stored but not read by the writer (hence `expect(dead_code)`).
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    /// Selects between the append-only and the upsert write path.
    is_append_only: bool,
    /// Currently open load stream; created in `write_batch` and taken in `barrier`.
    client: Option<StarrocksClient>,
    /// Client issuing the transaction requests; shared with the rollback task spawned on drop.
    txn_client: Arc<StarrocksTxnClient>,
    /// Encodes rows into JSON objects.
    row_encoder: JsonEncoder,
    /// Executor id, used as part of generated transaction labels.
    executor_id: u64,
    /// Label of the transaction that is currently open, if any.
    curr_txn_label: Option<String>,
}
377
378impl TryFrom<SinkParam> for StarrocksSink {
379 type Error = SinkError;
380
381 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
382 let schema = param.schema();
383 let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
384 StarrocksSink::new(param, config, schema)
385 }
386}
387
388impl StarrocksSinkWriter {
389 pub fn new(
390 config: StarrocksConfig,
391 schema: Schema,
392 pk_indices: Vec<usize>,
393 is_append_only: bool,
394 executor_id: u64,
395 ) -> Result<Self> {
396 let mut field_names = schema.names_str();
397 if !is_append_only {
398 field_names.push(STARROCKS_DELETE_SIGN);
399 };
400 let field_names = field_names
403 .into_iter()
404 .map(|name| format!("`{}`", name))
405 .collect::<Vec<String>>();
406 let field_names_str = field_names
407 .iter()
408 .map(|name| name.as_str())
409 .collect::<Vec<&str>>();
410
411 let header = HeaderBuilder::new()
412 .add_common_header()
413 .set_user_password(config.common.user.clone(), config.common.password.clone())
414 .add_json_format()
415 .set_partial_update(config.partial_update.clone())
416 .set_columns_name(field_names_str)
417 .set_db(config.common.database.clone())
418 .set_table(config.common.table.clone())
419 .build();
420
421 let url = if config.common.use_https {
422 format!("https://{}:{}", config.common.host, config.common.http_port)
423 } else {
424 format!("http://{}:{}", config.common.host, config.common.http_port)
425 };
426 let txn_request_builder =
427 StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
428
429 Ok(Self {
430 config,
431 schema: schema.clone(),
432 pk_indices,
433 is_append_only,
434 client: None,
435 txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
436 row_encoder: JsonEncoder::new_with_starrocks(schema, None),
437 executor_id,
438 curr_txn_label: None,
439 })
440 }
441
442 async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
443 for (op, row) in chunk.rows() {
444 if op != Op::Insert {
445 continue;
446 }
447 let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
448 self.client
449 .as_mut()
450 .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
451 .write(row_json_string.into())
452 .await?;
453 }
454 Ok(())
455 }
456
457 async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
458 for (op, row) in chunk.rows() {
459 match op {
460 Op::Insert => {
461 let mut row_json_value = self.row_encoder.encode(row)?;
462 row_json_value.insert(
463 STARROCKS_DELETE_SIGN.to_owned(),
464 Value::String("0".to_owned()),
465 );
466 let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
467 SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
468 })?;
469 self.client
470 .as_mut()
471 .ok_or_else(|| {
472 SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
473 })?
474 .write(row_json_string.into())
475 .await?;
476 }
477 Op::Delete => {
478 let mut row_json_value = self.row_encoder.encode(row)?;
479 row_json_value.insert(
480 STARROCKS_DELETE_SIGN.to_owned(),
481 Value::String("1".to_owned()),
482 );
483 let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
484 SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
485 })?;
486 self.client
487 .as_mut()
488 .ok_or_else(|| {
489 SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
490 })?
491 .write(row_json_string.into())
492 .await?;
493 }
494 Op::UpdateDelete => {}
495 Op::UpdateInsert => {
496 let mut row_json_value = self.row_encoder.encode(row)?;
497 row_json_value.insert(
498 STARROCKS_DELETE_SIGN.to_owned(),
499 Value::String("0".to_owned()),
500 );
501 let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
502 SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
503 })?;
504 self.client
505 .as_mut()
506 .ok_or_else(|| {
507 SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
508 })?
509 .write(row_json_string.into())
510 .await?;
511 }
512 }
513 }
514 Ok(())
515 }
516
517 #[inline(always)]
519 fn new_txn_label(&self) -> String {
520 format!(
521 "rw-txn-{}-{}",
522 self.executor_id,
523 chrono::Utc::now().timestamp_micros()
524 )
525 }
526
527 async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
528 tracing::debug!(?txn_label, "prepare transaction");
529 let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
530 if txn_label != txn_label_res {
531 return Err(SinkError::Starrocks(format!(
532 "label {} returned from prepare transaction {} differs from the current one",
533 txn_label, txn_label_res
534 )));
535 }
536 tracing::debug!(?txn_label, "commit transaction");
537 let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
538 if txn_label != txn_label_res {
539 return Err(SinkError::Starrocks(format!(
540 "label {} returned from commit transaction {} differs from the current one",
541 txn_label, txn_label_res
542 )));
543 }
544 Ok(())
545 }
546}
547
548impl Drop for StarrocksSinkWriter {
549 fn drop(&mut self) {
550 if let Some(txn_label) = self.curr_txn_label.take() {
551 let txn_client = self.txn_client.clone();
552 tokio::spawn(async move {
553 if let Err(e) = txn_client.rollback(txn_label.clone()).await {
554 tracing::error!(
555 "starrocks rollback transaction error: {:?}, txn label: {}",
556 e.as_report(),
557 txn_label
558 );
559 }
560 });
561 }
562 }
563}
564
565#[async_trait]
566impl SinkWriter for StarrocksSinkWriter {
567 async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
568 Ok(())
569 }
570
571 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
572 if self.curr_txn_label.is_none() {
576 let txn_label = self.new_txn_label();
577 tracing::debug!(?txn_label, "begin transaction");
578 let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
579 if txn_label != txn_label_res {
580 return Err(SinkError::Starrocks(format!(
581 "label {} returned from StarRocks {} differs from generated one",
582 txn_label, txn_label_res
583 )));
584 }
585 self.curr_txn_label = Some(txn_label.clone());
586 }
587 if self.client.is_none() {
588 let txn_label = self.curr_txn_label.clone();
589 self.client = Some(StarrocksClient::new(
590 self.txn_client.load(txn_label.unwrap()).await?,
591 ));
592 }
593 if self.is_append_only {
594 self.append_only(chunk).await
595 } else {
596 self.upsert(chunk).await
597 }
598 }
599
600 async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
601 if let Some(client) = self.client.take() {
602 client.finish().await?;
608 }
609
610 if is_checkpoint
611 && let Some(txn_label) = self.curr_txn_label.take()
612 && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
613 {
614 match self.txn_client.rollback(txn_label.clone()).await {
615 Ok(_) => tracing::warn!(
616 ?txn_label,
617 "transaction is successfully rolled back due to commit failure"
618 ),
619 Err(err) => {
620 tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
621 }
622 }
623
624 return Err(err);
625 }
626 Ok(())
627 }
628
629 async fn abort(&mut self) -> Result<()> {
630 if let Some(txn_label) = self.curr_txn_label.take() {
631 tracing::debug!(?txn_label, "rollback transaction");
632 self.txn_client.rollback(txn_label).await?;
633 }
634 Ok(())
635 }
636}
637
/// Reads table metadata from StarRocks over a MySQL-protocol connection.
pub struct StarrocksSchemaClient {
    /// Name of the table whose metadata is queried.
    table: String,
    /// Name of the database that contains `table`.
    db: String,
    /// Connection on which the `information_schema` queries are run.
    conn: mysql_async::Conn,
}
643
644impl StarrocksSchemaClient {
645 pub async fn new(
646 host: String,
647 port: String,
648 table: String,
649 db: String,
650 user: String,
651 password: String,
652 ) -> Result<Self> {
653 let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
656 let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
657
658 let conn_uri = format!(
659 "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
660 user,
661 password,
662 host,
663 port,
664 db,
665 STARROCK_MYSQL_PREFER_SOCKET,
666 STARROCK_MYSQL_MAX_ALLOWED_PACKET,
667 STARROCK_MYSQL_WAIT_TIMEOUT
668 );
669 let pool = mysql_async::Pool::new(
670 Opts::from_url(&conn_uri)
671 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
672 );
673 let conn = pool
674 .get_conn()
675 .await
676 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
677
678 Ok(Self { table, db, conn })
679 }
680
681 pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
682 let query = format!(
683 "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
684 self.table, self.db
685 );
686 let mut query_map: HashMap<String, String> = HashMap::default();
687 self.conn
688 .query_map(query, |(column_name, column_type)| {
689 query_map.insert(column_name, column_type)
690 })
691 .await
692 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
693 Ok(query_map)
694 }
695
696 pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
697 let query = format!(
698 "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
699 self.table, self.db
700 );
701 let table_mode_pk: (String, String) = self
702 .conn
703 .query_map(
704 query,
705 |(table_model, primary_key, sort_key): (String, String, String)| match table_model
706 .as_str()
707 {
708 "AGG_KEYS" => (table_model, sort_key),
712 _ => (table_model, primary_key),
713 },
714 )
715 .await
716 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
717 .first()
718 .ok_or_else(|| {
719 SinkError::Starrocks(format!(
720 "Can't find schema with table {:?} and database {:?}",
721 self.table, self.db
722 ))
723 })?
724 .clone();
725 Ok(table_mode_pk)
726 }
727}
728
/// JSON response body of StarRocks load and transaction requests.
///
/// Only `status` and `message` are required; every other field is optional. Within this
/// file only `status`, `message`, `label` and `error_url` are read.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    /// Transaction label echoed by StarRocks; compared with the label that was sent.
    #[serde(rename = "Label")]
    pub label: Option<String>,
    /// Request status; checked against `STARROCKS_SUCCESS_STATUS`.
    #[serde(rename = "Status")]
    pub status: String,
    /// Message reported by StarRocks; included in error reports.
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    /// Error URL reported by StarRocks, if any; included in error reports.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
768
/// Wrapper around one open load stream (`InserterInner`).
pub struct StarrocksClient {
    /// The underlying stream that rows are written to and that is finished at the end.
    insert: InserterInner,
}
772impl StarrocksClient {
773 pub fn new(insert: InserterInner) -> Self {
774 Self { insert }
775 }
776
777 pub async fn write(&mut self, data: Bytes) -> Result<()> {
778 self.insert.write(data).await?;
779 Ok(())
780 }
781
782 pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
783 let raw = self.insert.finish().await?;
784 let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
785 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
786
787 if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
788 return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
789 "Insert error: {}, {}, {:?}",
790 res.status,
791 res.message,
792 res.error_url,
793 )));
794 };
795 Ok(res)
796 }
797}
798
/// Client for the StarRocks transaction requests (begin, prepare, commit, rollback, load).
pub struct StarrocksTxnClient {
    /// Builds the request sender / inserter for each transaction operation.
    request_builder: StarrocksTxnRequestBuilder,
}
802
803impl StarrocksTxnClient {
804 pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
805 Self { request_builder }
806 }
807
808 fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
809 let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
810 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
811 if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
812 return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
813 "transaction error: {}, {}, {:?}",
814 res.status,
815 res.message,
816 res.error_url,
817 )));
818 }
819 res.label.ok_or_else(|| {
820 SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
821 })
822 }
823
824 pub async fn begin(&self, label: String) -> Result<String> {
825 let res = self
826 .request_builder
827 .build_begin_request_sender(label)?
828 .send()
829 .await?;
830 self.check_response_and_extract_label(res)
831 }
832
833 pub async fn prepare(&self, label: String) -> Result<String> {
834 let res = self
835 .request_builder
836 .build_prepare_request_sender(label)?
837 .send()
838 .await?;
839 self.check_response_and_extract_label(res)
840 }
841
842 pub async fn commit(&self, label: String) -> Result<String> {
843 let res = self
844 .request_builder
845 .build_commit_request_sender(label)?
846 .send()
847 .await?;
848 self.check_response_and_extract_label(res)
849 }
850
851 pub async fn rollback(&self, label: String) -> Result<String> {
852 let res = self
853 .request_builder
854 .build_rollback_request_sender(label)?
855 .send()
856 .await?;
857 self.check_response_and_extract_label(res)
858 }
859
860 pub async fn load(&self, label: String) -> Result<InserterInner> {
861 self.request_builder.build_txn_inserter(label).await
862 }
863}