use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use mysql_async::Opts;
use mysql_async::prelude::Queryable;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use serde::Deserialize;
use serde_derive::Serialize;
use serde_json::Value;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use url::form_urlencoded;
use with_options::WithOptions;

use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
use super::doris_starrocks_connector::{
    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
    StarrocksTxnRequestBuilder,
};
use super::encoder::{JsonEncoder, RowEncoder};
use super::{
    DummySinkCommitCoordinator, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
    SinkError, SinkParam, SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};

pub const STARROCKS_SINK: &str = "starrocks";
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;

const fn _default_stream_load_http_timeout_ms() -> u64 {
    30 * 1000
}

const fn default_use_https() -> bool {
    false
}

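/// Connection options shared by the StarRocks sink, populated from the `starrocks.*`
/// properties of the sink's `WITH` clause (see the `serde(rename)` attributes below).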
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The StarRocks host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port of the MySQL server of the StarRocks frontend (FE).
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port of the HTTP server of the StarRocks frontend (FE).
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the StarRocks database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The StarRocks database where the target table is located.
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The StarRocks table to sink data into.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use HTTPS when connecting to the StarRocks HTTP server. Defaults to `false`.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    pub use_https: bool,
}

impl EnforceSecret for StarrocksCommon {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}

#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// Timeout in milliseconds for the Stream Load HTTP requests. Defaults to 30000 (30 seconds).
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub stream_load_http_timeout_ms: u64,

    /// Commit the data to StarRocks every n checkpoints via the Stream Load transaction API.
    /// Must be greater than 0; defaults to `DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE`.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable partial update.
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    /// The sink type: either `append-only` or `upsert`.
    pub r#type: String,
}

impl EnforceSecret for StarrocksConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}

fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}

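/// `from_btreemap` builds and validates the config from the flattened string properties.
/// A minimal sketch of how it can be called (illustrative only; the host, ports, and other
/// values below are made-up placeholders):
///
/// ```ignore
/// use std::collections::BTreeMap;
///
/// let props: BTreeMap<String, String> = [
///     ("starrocks.host", "127.0.0.1"),
///     ("starrocks.mysqlport", "9030"),
///     ("starrocks.httpport", "8030"),
///     ("starrocks.user", "root"),
///     ("starrocks.password", ""),
///     ("starrocks.database", "demo_db"),
///     ("starrocks.table", "demo_table"),
///     ("type", "upsert"),
/// ]
/// .into_iter()
/// .map(|(k, v)| (k.to_owned(), v.to_owned()))
/// .collect();
///
/// let config = StarrocksConfig::from_btreemap(props)?;
/// assert_eq!(config.r#type, "upsert");
/// ```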
impl StarrocksConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be either {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        if config.commit_checkpoint_interval == 0 {
            return Err(SinkError::Config(anyhow!(
                "`commit_checkpoint_interval` must be greater than 0"
            )));
        }
        Ok(config)
    }
}

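/// StarRocks sink that writes rows through the StarRocks Stream Load transaction API.
/// In `upsert` mode, deletes are encoded by setting the `STARROCKS_DELETE_SIGN` column to
/// `"1"` on the emitted row (see `StarrocksSinkWriter::upsert`).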
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for StarrocksSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            StarrocksConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl StarrocksSink {
    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
        let pk_indices = param.downstream_pk.clone();
        let is_append_only = param.sink_type.is_append_only();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl StarrocksSink {
    fn check_column_name_and_type(
        &self,
        starrocks_columns_desc: HashMap<String, String>,
    ) -> Result<()> {
        let rw_fields_name = self.schema.fields();
        if rw_fields_name.len() > starrocks_columns_desc.len() {
            return Err(SinkError::Starrocks(
                "The columns of the sink must be a subset of the columns of the target StarRocks table."
                    .to_owned(),
            ));
        }

        for field in rw_fields_name {
            let value = starrocks_columns_desc.get(&field.name).ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Column {:?} was not found in the StarRocks table",
                    field.name
                ))
            })?;
            if !Self::check_and_correct_column_type(&field.data_type, value)? {
                return Err(SinkError::Starrocks(format!(
                    "Column type mismatch for column {:?}: StarRocks type is {:?}, RisingWave type is {:?}",
                    field.name, value, field.data_type
                )));
            }
        }
        Ok(())
    }

    fn check_and_correct_column_type(
        rw_data_type: &DataType,
        starrocks_data_type: &String,
    ) -> Result<bool> {
        match rw_data_type {
            DataType::Boolean => Ok(starrocks_data_type.contains("tinyint")
                || starrocks_data_type.contains("boolean")),
            DataType::Int16 => Ok(starrocks_data_type.contains("smallint")),
            DataType::Int32 => Ok(starrocks_data_type.contains("int")),
            DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
            DataType::Float32 => Ok(starrocks_data_type.contains("float")),
            DataType::Float64 => Ok(starrocks_data_type.contains("double")),
            DataType::Decimal => Ok(starrocks_data_type.contains("decimal")),
            DataType::Date => Ok(starrocks_data_type.contains("date")),
            DataType::Varchar => Ok(starrocks_data_type.contains("varchar")),
            DataType::Time => Err(SinkError::Starrocks(
                "TIME is not supported for StarRocks sink. Please convert to VARCHAR or another supported type.".to_owned(),
            )),
            DataType::Timestamp => Ok(starrocks_data_type.contains("datetime")),
            DataType::Timestamptz => Err(SinkError::Starrocks(
                "TIMESTAMP WITH TIME ZONE is not supported for StarRocks sink because StarRocks does not store timezone information. Please convert to TIMESTAMP first.".to_owned(),
            )),
            DataType::Interval => Err(SinkError::Starrocks(
                "INTERVAL is not supported for StarRocks sink. Please convert to VARCHAR or another supported type.".to_owned(),
            )),
            DataType::Struct(_) => Err(SinkError::Starrocks(
                "STRUCT is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::List(list) => {
                // `information_schema.columns` may report the type of an array column as
                // "unknown"; accept it in that case.
                if starrocks_data_type.contains("unknown") {
                    return Ok(true);
                }
                let check_result =
                    Self::check_and_correct_column_type(list.as_ref(), starrocks_data_type)?;
                Ok(check_result && starrocks_data_type.contains("array"))
            }
            DataType::Bytea => Err(SinkError::Starrocks(
                "BYTEA is not supported for StarRocks sink. Please convert to VARCHAR or another supported type.".to_owned(),
            )),
            DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
            DataType::Serial => Ok(starrocks_data_type.contains("bigint")),
            DataType::Int256 => Err(SinkError::Starrocks(
                "INT256 is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::Map(_) => Err(SinkError::Starrocks(
                "MAP is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
        }
    }
}

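/// `Sink` implementation: `validate` checks the primary key configuration and column types
/// against the live StarRocks table metadata (queried over the MySQL port), and
/// `new_log_sinker` builds a checkpoint-decoupled log sinker around `StarrocksSinkWriter`.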
impl Sink for StarrocksSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;

    const SINK_NAME: &'static str = STARROCKS_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert StarRocks sink (please define it in the `primary_key` field)"
            )));
        }
        let mut client = StarrocksSchemaClient::new(
            self.config.common.host.clone(),
            self.config.common.mysql_port.clone(),
            self.config.common.table.clone(),
            self.config.common.database.clone(),
            self.config.common.user.clone(),
            self.config.common.password.clone(),
        )
        .await?;
        let (read_model, pks) = client.get_pk_from_starrocks().await?;

        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
            return Err(SinkError::Config(anyhow!(
                "An upsert sink requires the StarRocks table to use the primary key model (table_model = PRIMARY_KEYS)"
            )));
        }

        for (index, field) in self.schema.fields().iter().enumerate() {
            if self.pk_indices.contains(&index) && !pks.contains(&field.name) {
                return Err(SinkError::Starrocks(format!(
                    "Can't find primary key {:?} in the StarRocks table",
                    field.name
                )));
            }
        }

        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;

        self.check_column_name_and_type(starrocks_columns_desc)?;
        Ok(())
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        StarrocksConfig::from_btreemap(config.clone())?;
        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        let writer = StarrocksSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
            writer_param.executor_id,
        )?;

        let metrics = SinkWriterMetrics::new(&writer_param);

        Ok(DecoupleCheckpointLogSinkerOf::new(
            writer,
            metrics,
            commit_checkpoint_interval,
        ))
    }
}

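/// Writer half of the StarRocks sink. It lazily begins a Stream Load transaction on the first
/// `write_batch` after a commit, streams JSON-encoded rows through `StarrocksClient`, finishes
/// the in-flight load on every barrier, and prepares + commits the transaction on checkpoint
/// barriers (see the `SinkWriter` implementation below).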
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: Option<StarrocksClient>,
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    executor_id: u64,
    curr_txn_label: Option<String>,
}

impl TryFrom<SinkParam> for StarrocksSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
        StarrocksSink::new(param, config, schema)
    }
}

impl StarrocksSinkWriter {
    pub fn new(
        config: StarrocksConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
        executor_id: u64,
    ) -> Result<Self> {
        let mut field_names = schema.names_str();
        if !is_append_only {
            field_names.push(STARROCKS_DELETE_SIGN);
        };
        // Quote field names with backticks so that column names colliding with reserved
        // words (e.g. `order`) are still accepted by StarRocks.
        let field_names = field_names
            .into_iter()
            .map(|name| format!("`{}`", name))
            .collect::<Vec<String>>();
        let field_names_str = field_names
            .iter()
            .map(|name| name.as_str())
            .collect::<Vec<&str>>();

        let header = HeaderBuilder::new()
            .add_common_header()
            .set_user_password(config.common.user.clone(), config.common.password.clone())
            .add_json_format()
            .set_partial_update(config.partial_update.clone())
            .set_columns_name(field_names_str)
            .set_db(config.common.database.clone())
            .set_table(config.common.table.clone())
            .build();

        let url = if config.common.use_https {
            format!("https://{}:{}", config.common.host, config.common.http_port)
        } else {
            format!("http://{}:{}", config.common.host, config.common.http_port)
        };
        let txn_request_builder =
            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;

        Ok(Self {
            config,
            schema: schema.clone(),
            pk_indices,
            is_append_only,
            client: None,
            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
            executor_id,
            curr_txn_label: None,
        })
    }

    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                continue;
            }
            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
            self.client
                .as_mut()
                .ok_or_else(|| {
                    SinkError::Starrocks("StarRocks sink client is not initialized".to_owned())
                })?
                .write(row_json_string.into())
                .await?;
        }
        Ok(())
    }

    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks sink client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::Delete => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("1".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks sink client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                // `UpdateDelete` is skipped: the following `UpdateInsert` carries the new
                // version of the row, which overwrites the old one under the same key.
                Op::UpdateDelete => {}
                Op::UpdateInsert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks(
                                "StarRocks sink client is not initialized".to_owned(),
                            )
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
            }
        }
        Ok(())
    }

    /// Generate a unique transaction label from the executor id and the current timestamp.
    #[inline(always)]
    fn new_txn_label(&self) -> String {
        format!(
            "rw-txn-{}-{}",
            self.executor_id,
            chrono::Utc::now().timestamp_micros()
        )
    }

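    /// Prepare and then commit the transaction identified by `txn_label`, checking that
    /// StarRocks echoes the same label back at each step.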
    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
        tracing::debug!(?txn_label, "prepare transaction");
        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "label {} returned from prepare transaction differs from the current label {}",
                txn_label_res, txn_label
            )));
        }
        tracing::debug!(?txn_label, "commit transaction");
        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "label {} returned from commit transaction differs from the current label {}",
                txn_label_res, txn_label
            )));
        }
        Ok(())
    }
}

impl Drop for StarrocksSinkWriter {
    fn drop(&mut self) {
        // If a transaction is still open when the writer is dropped, roll it back in the
        // background so it does not linger on the StarRocks side.
        if let Some(txn_label) = self.curr_txn_label.take() {
            let txn_client = self.txn_client.clone();
            tokio::spawn(async move {
                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
                    tracing::error!(
                        "starrocks rollback transaction error: {:?}, txn label: {}",
                        e.as_report(),
                        txn_label
                    );
                }
            });
        }
    }
}

#[async_trait]
impl SinkWriter for StarrocksSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // Lazily begin a new Stream Load transaction the first time data arrives after a
        // commit, and verify that StarRocks accepted the generated label.
        if self.curr_txn_label.is_none() {
            let txn_label = self.new_txn_label();
            tracing::debug!(?txn_label, "begin transaction");
            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
            if txn_label != txn_label_res {
                return Err(SinkError::Starrocks(format!(
                    "label {} returned from StarRocks differs from the generated label {}",
                    txn_label_res, txn_label
                )));
            }
            self.curr_txn_label = Some(txn_label.clone());
        }
        if self.client.is_none() {
            let txn_label = self.curr_txn_label.clone();
            self.client = Some(StarrocksClient::new(
                self.txn_client.load(txn_label.unwrap()).await?,
            ));
        }
        if self.is_append_only {
            self.append_only(chunk).await
        } else {
            self.upsert(chunk).await
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        // Finish the in-flight load request on every barrier. Multiple load requests may be
        // issued within a single commit interval; all data written under the same transaction
        // label is committed together when the transaction is eventually committed.
        if let Some(client) = self.client.take() {
            client.finish().await?;
        }

        if is_checkpoint
            && let Some(txn_label) = self.curr_txn_label.take()
            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
        {
            match self.txn_client.rollback(txn_label.clone()).await {
                Ok(_) => tracing::warn!(
                    ?txn_label,
                    "transaction was successfully rolled back after the commit failure"
                ),
                Err(err) => {
                    tracing::warn!(?txn_label, error = ?err.as_report(), "couldn't roll back the transaction after the commit failed")
                }
            }

            return Err(err);
        }
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        if let Some(txn_label) = self.curr_txn_label.take() {
            tracing::debug!(?txn_label, "rollback transaction");
            self.txn_client.rollback(txn_label).await?;
        }
        Ok(())
    }
}

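/// A small MySQL client that connects to the StarRocks frontend and reads table metadata
/// (column types and key columns) from `information_schema`, used for sink validation.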
pub struct StarrocksSchemaClient {
    table: String,
    db: String,
    conn: mysql_async::Conn,
}

impl StarrocksSchemaClient {
    pub async fn new(
        host: String,
        port: String,
        table: String,
        db: String,
        user: String,
        password: String,
    ) -> Result<Self> {
        // The user name and password may contain characters that are not allowed in a URI,
        // so percent-encode them before building the connection string.
        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();

        let conn_uri = format!(
            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
            user,
            password,
            host,
            port,
            db,
            STARROCK_MYSQL_PREFER_SOCKET,
            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
            STARROCK_MYSQL_WAIT_TIMEOUT
        );
        let pool = mysql_async::Pool::new(
            Opts::from_url(&conn_uri)
                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
        );
        let conn = pool
            .get_conn()
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        Ok(Self { table, db, conn })
    }

    /// Fetch the column names and types of the target table from `information_schema.columns`.
    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
        let query = format!(
            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let mut query_map: HashMap<String, String> = HashMap::default();
        self.conn
            .query_map(query, |(column_name, column_type)| {
                query_map.insert(column_name, column_type)
            })
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        Ok(query_map)
    }

    /// Fetch the table model (e.g. `PRIMARY_KEYS`) and its key columns from
    /// `information_schema.tables_config`.
    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
        let query = format!(
            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let table_mode_pk: (String, String) = self
            .conn
            .query_map(
                query,
                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
                    .as_str()
                {
                    // Aggregate-key tables report their key columns in `sort_key` instead of
                    // `primary_key`.
                    "AGG_KEYS" => (table_model, sort_key),
                    _ => (table_model, primary_key),
                },
            )
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
            .first()
            .ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Can't find schema with table {:?} and database {:?}",
                    self.table, self.db
                ))
            })?
            .clone();
        Ok(table_mode_pk)
    }
}

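/// Response body returned by the StarRocks Stream Load and transaction HTTP endpoints.
/// Only `Status` and `Message` are always present; all other fields are optional.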
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}

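/// Wrapper around an in-flight Stream Load request: `write` streams row payloads into it and
/// `finish` completes the request and checks the returned status.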
pub struct StarrocksClient {
    insert: InserterInner,
}

impl StarrocksClient {
    pub fn new(insert: InserterInner) -> Self {
        Self { insert }
    }

    pub async fn write(&mut self, data: Bytes) -> Result<()> {
        self.insert.write(data).await?;
        Ok(())
    }

    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
        let raw = self.insert.finish().await?;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "Insert error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        };
        Ok(res)
    }
}

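/// Thin client over the Stream Load transaction endpoints: `begin`, `prepare`, `commit`,
/// `rollback`, and `load`, which opens an inserter bound to a transaction label.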
pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}

impl StarrocksTxnClient {
    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
        Self { request_builder }
    }

    /// Parse a transaction response, erroring on a non-success status, and return the label
    /// echoed back by StarRocks.
    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "transaction error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        }
        res.label.ok_or_else(|| {
            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
        })
    }

    pub async fn begin(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_begin_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn prepare(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_prepare_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn commit(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_commit_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn rollback(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_rollback_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn load(&self, label: String) -> Result<InserterInner> {
        self.request_builder.build_txn_inserter(label).await
    }
}