1use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use chrono_tz::Tz;
23use mysql_async::Opts;
24use mysql_async::prelude::Queryable;
25use risingwave_common::array::{Op, StreamChunk};
26use risingwave_common::catalog::Schema;
27use risingwave_common::types::DataType;
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30use serde_with::{DisplayFromStr, serde_as};
31use thiserror_ext::AsReport;
32use url::form_urlencoded;
33use uuid::Uuid;
34use with_options::WithOptions;
35
36use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
37use super::doris_starrocks_connector::{
38 HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
39 StarrocksTxnRequestBuilder,
40};
41use super::encoder::{JsonEncoder, RowEncoder};
42use super::{
43 SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
44 SinkWriterMetrics,
45};
46use crate::enforce_secret::EnforceSecret;
47use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
48use crate::sink::writer::SinkWriter;
49use crate::sink::{Result, Sink, SinkWriterParam};
50
/// Connector name of the StarRocks sink; used as `Sink::SINK_NAME` for `StarrocksSink`.
pub const STARROCKS_SINK: &str = "starrocks";
// Query-string options appended to the MySQL connection URL built in
// `StarrocksSchemaClient::new` (the schema-validation connection, not the stream-load path).
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
// NOTE(review): unit is whatever mysql_async expects for `max_allowed_packet` — confirm.
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
// NOTE(review): presumably seconds (MySQL `wait_timeout` convention) — confirm.
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;
55
/// Serde default for `starrocks.stream_load.http.timeout.ms`: thirty seconds, expressed in
/// milliseconds.
pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS * 1_000
}
59
/// Serde default for `starrocks.use_https`: the stream-load URL uses plain `http://` unless the
/// property is explicitly set.
const fn default_use_https() -> bool {
    const USE_HTTPS_BY_DEFAULT: bool = false;
    USE_HTTPS_BY_DEFAULT
}
63
// Connection settings shared by the StarRocks sink, deserialized from the `starrocks.*`
// sink properties. All values arrive as strings; `use_https` is parsed via `DisplayFromStr`.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    // StarRocks host. Used both for the MySQL-protocol schema queries made during `validate`
    // and for the stream-load HTTP URL built by the writer.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    // MySQL-protocol port used by `StarrocksSchemaClient`; `starrocks.query_port` is accepted
    // as an alias. Kept as a string and spliced into the connection URL as-is.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    // HTTP port of the stream-load endpoint; `starrocks.http_port` is accepted as an alias.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    // User name, used for the MySQL connection and the stream-load header.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    // Password, used for the MySQL connection and the stream-load header.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    // Target database.
    #[serde(rename = "starrocks.database")]
    pub database: String,
    // Target table.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    // When true, the writer builds an `https://` stream-load URL instead of `http://`.
    // Defaults to false.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}
95
/// Declares which StarRocks connection properties are subject to secret enforcement.
impl EnforceSecret for StarrocksCommon {
    // Both the credentials are listed, not only the password.
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}
101
// Full StarRocks sink configuration: the shared connection settings plus sink-level options.
// Built from the raw property map by `StarrocksConfig::from_btreemap`.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    // Connection settings; flattened so their `starrocks.*` keys sit at the top level.
    #[serde(flatten)]
    pub common: StarrocksCommon,

    // Timeout (milliseconds) passed to `StarrocksTxnRequestBuilder::new` for the stream-load /
    // transaction HTTP requests. Defaults to 30 000 ms; alterable on the fly.
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    // Commit interval handed to `DecoupleCheckpointLogSinkerOf` (presumably: commit every N
    // checkpoints — confirm against that type). Must be non-zero; `from_btreemap` rejects 0.
    // Alterable on the fly.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    // Optional value forwarded verbatim to the stream-load header via
    // `HeaderBuilder::set_partial_update`.
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    // Sink type, read from the `type` property (serde drops the `r#` prefix). `from_btreemap`
    // accepts only `SINK_TYPE_APPEND_ONLY` or `SINK_TYPE_UPSERT`.
    pub r#type: String,
}
133
/// Secret enforcement for the full config simply delegates to [`StarrocksCommon`]: the
/// sink-level options add no secret properties of their own.
impl EnforceSecret for StarrocksConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}
139
/// Serde default for `commit_checkpoint_interval`; reuses
/// `DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE` from `decouple_checkpoint_log_sink`.
fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}
143
144impl StarrocksConfig {
145 pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
146 let config =
147 serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
148 .map_err(|e| SinkError::Config(anyhow!(e)))?;
149 if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
150 return Err(SinkError::Config(anyhow!(
151 "`{}` must be {}, or {}",
152 SINK_TYPE_OPTION,
153 SINK_TYPE_APPEND_ONLY,
154 SINK_TYPE_UPSERT
155 )));
156 }
157 if config.commit_checkpoint_interval == 0 {
158 return Err(SinkError::Config(anyhow!(
159 "`commit_checkpoint_interval` must be greater than 0"
160 )));
161 }
162 Ok(config)
163 }
164}
165
/// StarRocks sink: validated configuration plus the sink schema and key information.
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    /// Schema of the rows written by this sink.
    schema: Schema,
    /// Indices into `schema` of the downstream primary-key columns (empty if none).
    pk_indices: Vec<usize>,
    /// Whether the sink type is append-only; otherwise rows are written as upserts.
    is_append_only: bool,
}
173
174impl EnforceSecret for StarrocksSink {
175 fn enforce_secret<'a>(
176 prop_iter: impl Iterator<Item = &'a str>,
177 ) -> crate::error::ConnectorResult<()> {
178 for prop in prop_iter {
179 StarrocksConfig::enforce_one(prop)?;
180 }
181 Ok(())
182 }
183}
184
185impl StarrocksSink {
186 pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
187 let pk_indices = param.downstream_pk_or_empty();
188 let is_append_only = param.sink_type.is_append_only();
189 Ok(Self {
190 config,
191 schema,
192 pk_indices,
193 is_append_only,
194 })
195 }
196}
197
198impl StarrocksSink {
199 fn starrocks_data_type_contains_any(
200 starrocks_data_type: &str,
201 expected_types: &[&str],
202 ) -> bool {
203 expected_types
204 .iter()
205 .any(|expected_type| starrocks_data_type.contains(expected_type))
206 }
207
208 fn check_column_name_and_type(
209 &self,
210 starrocks_columns_desc: HashMap<String, String>,
211 ) -> Result<()> {
212 let rw_fields_name = self.schema.fields();
213 if rw_fields_name.len() > starrocks_columns_desc.len() {
214 return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
215 }
216
217 for i in rw_fields_name {
218 let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
219 SinkError::Starrocks(format!(
220 "Column name don't find in starrocks, risingwave is {:?} ",
221 i.name
222 ))
223 })?;
224 if !Self::check_and_correct_column_type(&i.data_type, value)? {
225 return Err(SinkError::Starrocks(format!(
226 "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
227 i.name, value, i.data_type
228 )));
229 }
230 }
231 Ok(())
232 }
233
234 fn check_and_correct_column_type(
235 rw_data_type: &DataType,
236 starrocks_data_type: &str,
237 ) -> Result<bool> {
238 match rw_data_type {
239 risingwave_common::types::DataType::Boolean => {
240 Ok(Self::starrocks_data_type_contains_any(
241 starrocks_data_type,
242 &["tinyint", "boolean"],
243 ))
244 }
245 risingwave_common::types::DataType::Int16 => {
246 Ok(starrocks_data_type.contains("smallint"))
247 }
248 risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
249 risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
250 risingwave_common::types::DataType::Float32 => {
251 Ok(starrocks_data_type.contains("float"))
252 }
253 risingwave_common::types::DataType::Float64 => {
254 Ok(starrocks_data_type.contains("double"))
255 }
256 risingwave_common::types::DataType::Decimal => {
257 Ok(starrocks_data_type.contains("decimal"))
258 }
259 risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
260 risingwave_common::types::DataType::Varchar => {
261 Ok(starrocks_data_type.contains("varchar"))
262 }
263 risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
264 "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
265 )),
266 risingwave_common::types::DataType::Timestamp => {
267 Ok(starrocks_data_type.contains("datetime"))
268 }
269 risingwave_common::types::DataType::Timestamptz => {
270 Ok(Self::starrocks_data_type_contains_any(
273 starrocks_data_type,
274 &["datetime", "varchar", "char", "string"],
275 ))
276 }
277 risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
278 "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
279 )),
280 risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
281 "STRUCT is not supported for Starrocks sink.".to_owned(),
282 )),
283 risingwave_common::types::DataType::List(list) => {
284 if starrocks_data_type.contains("unknown") {
286 return Ok(true);
287 }
288 let check_result = Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
289 Ok(check_result && starrocks_data_type.contains("array"))
290 }
291 risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
292 "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
293 )),
294 risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
295 risingwave_common::types::DataType::Serial => {
296 Ok(starrocks_data_type.contains("bigint"))
297 }
298 risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
299 "INT256 is not supported for Starrocks sink.".to_owned(),
300 )),
301 risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
302 "MAP is not supported for Starrocks sink.".to_owned(),
303 )),
304 DataType::Vector(_) => Err(SinkError::Starrocks(
305 "VECTOR is not supported for Starrocks sink.".to_owned(),
306 )),
307 }
308 }
309}
310
311impl Sink for StarrocksSink {
312 type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
313
314 const SINK_NAME: &'static str = STARROCKS_SINK;
315
316 async fn validate(&self) -> Result<()> {
317 if !self.is_append_only && self.pk_indices.is_empty() {
318 return Err(SinkError::Config(anyhow!(
319 "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
320 )));
321 }
322 let mut client = StarrocksSchemaClient::new(
324 self.config.common.host.clone(),
325 self.config.common.mysql_port.clone(),
326 self.config.common.table.clone(),
327 self.config.common.database.clone(),
328 self.config.common.user.clone(),
329 self.config.common.password.clone(),
330 )
331 .await?;
332 let (read_model, pks) = client.get_pk_from_starrocks().await?;
333
334 if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
335 return Err(SinkError::Config(anyhow!(
336 "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
337 )));
338 }
339
340 for (index, filed) in self.schema.fields().iter().enumerate() {
341 if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
342 return Err(SinkError::Starrocks(format!(
343 "Can't find pk {:?} in starrocks",
344 filed.name
345 )));
346 }
347 }
348
349 let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
350
351 self.check_column_name_and_type(starrocks_columns_desc)?;
352 Ok(())
353 }
354
355 fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
356 StarrocksConfig::from_btreemap(config.clone())?;
357 Ok(())
358 }
359
360 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
361 let commit_checkpoint_interval =
362 NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
363 "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
364 );
365
366 let writer = StarrocksSinkWriter::new(
367 self.config.clone(),
368 self.schema.clone(),
369 self.pk_indices.clone(),
370 self.is_append_only,
371 writer_param.time_zone,
372 )?;
373
374 let metrics = SinkWriterMetrics::new(&writer_param);
375
376 Ok(DecoupleCheckpointLogSinkerOf::new(
377 writer,
378 metrics,
379 commit_checkpoint_interval,
380 ))
381 }
382}
383
/// Writer that streams rows into StarRocks inside explicit stream-load transactions.
///
/// A transaction is begun lazily on the first `write_batch`, and rows go through a stream
/// load bound to that transaction. The transaction is prepared and committed on a checkpoint
/// barrier.
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    /// Append-only mode writes inserts only; otherwise rows carry `STARROCKS_DELETE_SIGN`.
    is_append_only: bool,
    /// Open stream load for the current transaction, if any; closed in `barrier`.
    client: Option<StarrocksClient>,
    /// Shared with the rollback task spawned from `Drop`.
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    /// Label of the in-flight transaction, if one has been begun and not yet committed or
    /// rolled back.
    curr_txn_label: Option<String>,
}
396
397impl TryFrom<SinkParam> for StarrocksSink {
398 type Error = SinkError;
399
400 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
401 let schema = param.schema();
402 let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
403 StarrocksSink::new(param, config, schema)
404 }
405}
406
407impl StarrocksSinkWriter {
408 pub fn new(
409 config: StarrocksConfig,
410 schema: Schema,
411 pk_indices: Vec<usize>,
412 is_append_only: bool,
413 time_zone: Tz,
414 ) -> Result<Self> {
415 let mut field_names = schema.names_str();
416 if !is_append_only {
417 field_names.push(STARROCKS_DELETE_SIGN);
418 };
419 let field_names = field_names
422 .into_iter()
423 .map(|name| format!("`{}`", name))
424 .collect::<Vec<String>>();
425 let field_names_str = field_names
426 .iter()
427 .map(|name| name.as_str())
428 .collect::<Vec<&str>>();
429
430 let header = HeaderBuilder::new()
431 .add_common_header()
432 .set_user_password(config.common.user.clone(), config.common.password.clone())
433 .add_json_format()
434 .set_partial_update(config.partial_update.clone())
435 .set_columns_name(field_names_str)
436 .set_db(config.common.database.clone())
437 .set_table(config.common.table.clone())
438 .build();
439
440 let url = if config.common.use_https {
441 format!("https://{}:{}", config.common.host, config.common.http_port)
442 } else {
443 format!("http://{}:{}", config.common.host, config.common.http_port)
444 };
445 let txn_request_builder =
446 StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
447
448 Ok(Self {
449 config,
450 schema: schema.clone(),
451 pk_indices,
452 is_append_only,
453 client: None,
454 txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
455 row_encoder: JsonEncoder::new_with_starrocks(schema, None, time_zone),
456 curr_txn_label: None,
457 })
458 }
459
460 async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
461 for (op, row) in chunk.rows() {
462 if op != Op::Insert {
463 continue;
464 }
465 let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
466 self.client
467 .as_mut()
468 .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
469 .write(row_json_string.into())
470 .await?;
471 }
472 Ok(())
473 }
474
475 async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
476 for (op, row) in chunk.rows() {
477 match op {
478 Op::Insert => {
479 let mut row_json_value = self.row_encoder.encode(row)?;
480 row_json_value.insert(
481 STARROCKS_DELETE_SIGN.to_owned(),
482 Value::String("0".to_owned()),
483 );
484 let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
485 SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
486 })?;
487 self.client
488 .as_mut()
489 .ok_or_else(|| {
490 SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
491 })?
492 .write(row_json_string.into())
493 .await?;
494 }
495 Op::Delete => {
496 let mut row_json_value = self.row_encoder.encode(row)?;
497 row_json_value.insert(
498 STARROCKS_DELETE_SIGN.to_owned(),
499 Value::String("1".to_owned()),
500 );
501 let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
502 SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
503 })?;
504 self.client
505 .as_mut()
506 .ok_or_else(|| {
507 SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
508 })?
509 .write(row_json_string.into())
510 .await?;
511 }
512 Op::UpdateDelete => {}
513 Op::UpdateInsert => {
514 let mut row_json_value = self.row_encoder.encode(row)?;
515 row_json_value.insert(
516 STARROCKS_DELETE_SIGN.to_owned(),
517 Value::String("0".to_owned()),
518 );
519 let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
520 SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
521 })?;
522 self.client
523 .as_mut()
524 .ok_or_else(|| {
525 SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
526 })?
527 .write(row_json_string.into())
528 .await?;
529 }
530 }
531 }
532 Ok(())
533 }
534
535 #[inline(always)]
537 fn new_txn_label(&self) -> String {
538 format!(
539 "rw-txn-{}-{}",
540 Uuid::new_v4(),
541 chrono::Utc::now().timestamp_micros()
542 )
543 }
544
545 async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
546 tracing::debug!(?txn_label, "prepare transaction");
547 let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
548 if txn_label != txn_label_res {
549 return Err(SinkError::Starrocks(format!(
550 "label {} returned from prepare transaction {} differs from the current one",
551 txn_label, txn_label_res
552 )));
553 }
554 tracing::debug!(?txn_label, "commit transaction");
555 let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
556 if txn_label != txn_label_res {
557 return Err(SinkError::Starrocks(format!(
558 "label {} returned from commit transaction {} differs from the current one",
559 txn_label, txn_label_res
560 )));
561 }
562 Ok(())
563 }
564}
565
566impl Drop for StarrocksSinkWriter {
567 fn drop(&mut self) {
568 if let Some(txn_label) = self.curr_txn_label.take() {
569 let txn_client = self.txn_client.clone();
570 tokio::spawn(async move {
571 if let Err(e) = txn_client.rollback(txn_label.clone()).await {
572 tracing::error!(
573 "starrocks rollback transaction error: {:?}, txn label: {}",
574 e.as_report(),
575 txn_label
576 );
577 }
578 });
579 }
580 }
581}
582
583#[async_trait]
584impl SinkWriter for StarrocksSinkWriter {
585 async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
586 Ok(())
587 }
588
589 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
590 if self.curr_txn_label.is_none() {
594 let txn_label = self.new_txn_label();
595 tracing::debug!(?txn_label, "begin transaction");
596 let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
597 if txn_label != txn_label_res {
598 return Err(SinkError::Starrocks(format!(
599 "label {} returned from StarRocks {} differs from generated one",
600 txn_label, txn_label_res
601 )));
602 }
603 self.curr_txn_label = Some(txn_label.clone());
604 }
605 if self.client.is_none() {
606 let txn_label = self.curr_txn_label.clone();
607 self.client = Some(StarrocksClient::new(
608 self.txn_client.load(txn_label.unwrap()).await?,
609 ));
610 }
611 if self.is_append_only {
612 self.append_only(chunk).await
613 } else {
614 self.upsert(chunk).await
615 }
616 }
617
618 async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
619 if let Some(client) = self.client.take() {
620 client.finish().await?;
626 }
627
628 if is_checkpoint
629 && let Some(txn_label) = self.curr_txn_label.take()
630 && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
631 {
632 match self.txn_client.rollback(txn_label.clone()).await {
633 Ok(_) => tracing::warn!(
634 ?txn_label,
635 "transaction is successfully rolled back due to commit failure"
636 ),
637 Err(err) => {
638 tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
639 }
640 }
641
642 return Err(err);
643 }
644 Ok(())
645 }
646
647 async fn abort(&mut self) -> Result<()> {
648 if let Some(txn_label) = self.curr_txn_label.take() {
649 tracing::debug!(?txn_label, "rollback transaction");
650 self.txn_client.rollback(txn_label).await?;
651 }
652 Ok(())
653 }
654}
655
/// MySQL-protocol client used during `validate` to read the target table's key model, key
/// columns and column types from StarRocks' `information_schema`.
pub struct StarrocksSchemaClient {
    /// Target table name, used to filter the `information_schema` queries.
    table: String,
    /// Target database name, used to filter the `information_schema` queries.
    db: String,
    conn: mysql_async::Conn,
}
661
662impl StarrocksSchemaClient {
663 pub async fn new(
664 host: String,
665 port: String,
666 table: String,
667 db: String,
668 user: String,
669 password: String,
670 ) -> Result<Self> {
671 let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
674 let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
675
676 let conn_uri = format!(
677 "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
678 user,
679 password,
680 host,
681 port,
682 db,
683 STARROCK_MYSQL_PREFER_SOCKET,
684 STARROCK_MYSQL_MAX_ALLOWED_PACKET,
685 STARROCK_MYSQL_WAIT_TIMEOUT
686 );
687 let pool = mysql_async::Pool::new(
688 Opts::from_url(&conn_uri)
689 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
690 );
691 let conn = pool
692 .get_conn()
693 .await
694 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
695
696 Ok(Self { table, db, conn })
697 }
698
699 pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
700 let query = format!(
701 "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
702 self.table, self.db
703 );
704 let mut query_map: HashMap<String, String> = HashMap::default();
705 self.conn
706 .query_map(query, |(column_name, column_type)| {
707 query_map.insert(column_name, column_type)
708 })
709 .await
710 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
711 Ok(query_map)
712 }
713
714 pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
715 let query = format!(
716 "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
717 self.table, self.db
718 );
719 let table_mode_pk: (String, String) = self
720 .conn
721 .query_map(
722 query,
723 |(table_model, primary_key, sort_key): (String, String, String)| match table_model
724 .as_str()
725 {
726 "AGG_KEYS" => (table_model, sort_key),
730 _ => (table_model, primary_key),
731 },
732 )
733 .await
734 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
735 .first()
736 .ok_or_else(|| {
737 SinkError::Starrocks(format!(
738 "Can't find schema for StarRocks table {:?} in database {:?}. Please check that the table exists and the StarRocks user has write permission on the target table.",
739 self.table, self.db
740 ))
741 })?
742 .clone();
743 Ok(table_mode_pk)
744 }
745}
746
/// JSON body returned by StarRocks for stream-load and transaction (begin / prepare / commit /
/// rollback / load) requests.
///
/// Only `Status` and `Message` are required. Everything else is optional because different
/// endpoints return different subsets. Callers treat the request as successful when `status`
/// is in `STARROCKS_SUCCESS_STATUS`.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    /// Transaction label echoed back by StarRocks; compared against the label that was sent.
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    /// URL with error details, included in the error message when a request fails.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
786
/// Thin wrapper around one open stream-load request (`InserterInner`) that decodes and checks
/// the StarRocks response when the load is finished.
pub struct StarrocksClient {
    insert: InserterInner,
}
790impl StarrocksClient {
791 pub fn new(insert: InserterInner) -> Self {
792 Self { insert }
793 }
794
795 pub async fn write(&mut self, data: Bytes) -> Result<()> {
796 self.insert.write(data).await?;
797 Ok(())
798 }
799
800 pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
801 let raw = self.insert.finish().await?;
802 let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
803 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
804
805 if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
806 return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
807 "Insert error: {}, {}, {:?}",
808 res.status,
809 res.message,
810 res.error_url,
811 )));
812 };
813 Ok(res)
814 }
815}
816
/// Client for the StarRocks stream-load transaction API (begin / prepare / commit / rollback,
/// plus opening a load inside a transaction), built on a `StarrocksTxnRequestBuilder`.
pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}
820
821impl StarrocksTxnClient {
822 pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
823 Self { request_builder }
824 }
825
826 fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
827 let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
828 .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
829 if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
830 return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
831 "transaction error: {}, {}, {:?}",
832 res.status,
833 res.message,
834 res.error_url,
835 )));
836 }
837 res.label.ok_or_else(|| {
838 SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
839 })
840 }
841
842 pub async fn begin(&self, label: String) -> Result<String> {
843 let res = self
844 .request_builder
845 .build_begin_request_sender(label)?
846 .send()
847 .await?;
848 self.check_response_and_extract_label(res)
849 }
850
851 pub async fn prepare(&self, label: String) -> Result<String> {
852 let res = self
853 .request_builder
854 .build_prepare_request_sender(label)?
855 .send()
856 .await?;
857 self.check_response_and_extract_label(res)
858 }
859
860 pub async fn commit(&self, label: String) -> Result<String> {
861 let res = self
862 .request_builder
863 .build_commit_request_sender(label)?
864 .send()
865 .await?;
866 self.check_response_and_extract_label(res)
867 }
868
869 pub async fn rollback(&self, label: String) -> Result<String> {
870 let res = self
871 .request_builder
872 .build_rollback_request_sender(label)?
873 .send()
874 .await?;
875 self.check_response_and_extract_label(res)
876 }
877
878 pub async fn load(&self, label: String) -> Result<InserterInner> {
879 self.request_builder.build_txn_inserter(label).await
880 }
881}
882
883#[cfg(test)]
884mod tests {
885 use risingwave_common::types::DataType;
886
887 use super::*;
888
889 fn is_compatible(rw_data_type: DataType, starrocks_data_type: &str) -> bool {
890 StarrocksSink::check_and_correct_column_type(&rw_data_type, starrocks_data_type).unwrap()
891 }
892
893 #[test]
894 fn test_timestamptz_compatible_starrocks_types() {
895 for starrocks_data_type in [
896 "datetime",
897 "datetime(6)",
898 "varchar(64)",
899 "char(32)",
900 "string",
901 ] {
902 assert!(
903 is_compatible(DataType::Timestamptz, starrocks_data_type),
904 "{starrocks_data_type} should be compatible with timestamptz"
905 );
906 }
907 }
908
909 #[test]
910 fn test_timestamptz_incompatible_starrocks_types() {
911 for starrocks_data_type in ["date", "int", "bigint", "json", "boolean"] {
912 assert!(
913 !is_compatible(DataType::Timestamptz, starrocks_data_type),
914 "{starrocks_data_type} should not be compatible with timestamptz"
915 );
916 }
917 }
918
919}