risingwave_connector/sink/starrocks.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use mysql_async::Opts;
23use mysql_async::prelude::Queryable;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::types::DataType;
27use risingwave_pb::connector_service::SinkMetadata;
28use risingwave_pb::connector_service::sink_metadata::Metadata::Serialized;
29use risingwave_pb::connector_service::sink_metadata::SerializedMetadata;
30use sea_orm::DatabaseConnection;
31use serde::Deserialize;
32use serde_derive::Serialize;
33use serde_json::Value;
34use serde_with::{DisplayFromStr, serde_as};
35use thiserror_ext::AsReport;
36use url::form_urlencoded;
37use with_options::WithOptions;
38
39use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
40use super::doris_starrocks_connector::{
41    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
42    StarrocksTxnRequestBuilder,
43};
44use super::encoder::{JsonEncoder, RowEncoder};
45use super::{
46    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkCommitCoordinator,
47    SinkCommittedEpochSubscriber, SinkError, SinkParam,
48};
49use crate::sink::coordinate::CoordinatedLogSinker;
50use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};
51
/// Connector name of the `StarRocks` sink; used as `Sink::SINK_NAME` for [`StarrocksSink`].
pub const STARROCKS_SINK: &str = "starrocks";
/// Value of the `prefer_socket` query parameter in the MySQL connection URL built by
/// `StarrocksSchemaClient::new`.
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
/// Value of the `max_allowed_packet` query parameter in that MySQL connection URL.
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
/// Value of the `wait_timeout` query parameter in that MySQL connection URL
/// (presumably seconds — confirm against the `mysql_async` URL option docs).
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;
56
/// Serde default for `starrocks.stream_load.http.timeout.ms`: 30 seconds, in milliseconds.
const fn _default_stream_load_http_timeout_ms() -> u64 {
    const THIRTY_SECONDS_MS: u64 = 30_000;
    THIRTY_SECONDS_MS
}
60
/// Connection settings for a `StarRocks` FE and the target table.
///
/// The same `host` is used both for the MySQL-protocol connection (schema validation) and for
/// the HTTP stream-load / transaction endpoints. Ports are kept as strings because they are
/// only ever interpolated into connection URLs.
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,
}
85
86#[serde_as]
87#[derive(Clone, Debug, Deserialize, WithOptions)]
88pub struct StarrocksConfig {
89    #[serde(flatten)]
90    pub common: StarrocksCommon,
91
92    /// The timeout in milliseconds for stream load http request, defaults to 10 seconds.
93    #[serde(
94        rename = "starrocks.stream_load.http.timeout.ms",
95        default = "_default_stream_load_http_timeout_ms"
96    )]
97    #[serde_as(as = "DisplayFromStr")]
98    pub stream_load_http_timeout_ms: u64,
99
100    /// Set this option to a positive integer n, RisingWave will try to commit data
101    /// to Starrocks at every n checkpoints by leveraging the
102    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/),
103    /// also, in this time, the `sink_decouple` option should be enabled as well.
104    /// Defaults to 10 if commit_checkpoint_interval <= 0
105    #[serde(default = "default_commit_checkpoint_interval")]
106    #[serde_as(as = "DisplayFromStr")]
107    pub commit_checkpoint_interval: u64,
108
109    /// Enable partial update
110    #[serde(rename = "starrocks.partial_update")]
111    pub partial_update: Option<String>,
112
113    pub r#type: String, // accept "append-only" or "upsert"
114}
115
116fn default_commit_checkpoint_interval() -> u64 {
117    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
118}
119
120impl StarrocksConfig {
121    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
122        let config =
123            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
124                .map_err(|e| SinkError::Config(anyhow!(e)))?;
125        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
126            return Err(SinkError::Config(anyhow!(
127                "`{}` must be {}, or {}",
128                SINK_TYPE_OPTION,
129                SINK_TYPE_APPEND_ONLY,
130                SINK_TYPE_UPSERT
131            )));
132        }
133        if config.commit_checkpoint_interval == 0 {
134            return Err(SinkError::Config(anyhow!(
135                "`commit_checkpoint_interval` must be greater than 0"
136            )));
137        }
138        Ok(config)
139    }
140}
141
/// Coordinated sink that writes into a `StarRocks` table via the stream-load transaction API.
#[derive(Debug)]
pub struct StarrocksSink {
    param: SinkParam,
    pub config: StarrocksConfig,
    schema: Schema,
    // Copied from `param.downstream_pk` in `new`.
    pk_indices: Vec<usize>,
    // Derived from `param.sink_type` in `new`.
    is_append_only: bool,
}
150
151impl StarrocksSink {
152    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
153        let pk_indices = param.downstream_pk.clone();
154        let is_append_only = param.sink_type.is_append_only();
155        Ok(Self {
156            param,
157            config,
158            schema,
159            pk_indices,
160            is_append_only,
161        })
162    }
163}
164
165impl StarrocksSink {
166    fn check_column_name_and_type(
167        &self,
168        starrocks_columns_desc: HashMap<String, String>,
169    ) -> Result<()> {
170        let rw_fields_name = self.schema.fields();
171        if rw_fields_name.len() > starrocks_columns_desc.len() {
172            return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
173        }
174
175        for i in rw_fields_name {
176            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
177                SinkError::Starrocks(format!(
178                    "Column name don't find in starrocks, risingwave is {:?} ",
179                    i.name
180                ))
181            })?;
182            if !Self::check_and_correct_column_type(&i.data_type, value)? {
183                return Err(SinkError::Starrocks(format!(
184                    "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
185                    i.name, value, i.data_type
186                )));
187            }
188        }
189        Ok(())
190    }
191
192    fn check_and_correct_column_type(
193        rw_data_type: &DataType,
194        starrocks_data_type: &String,
195    ) -> Result<bool> {
196        match rw_data_type {
197            risingwave_common::types::DataType::Boolean => {
198                Ok(starrocks_data_type.contains("tinyint") | starrocks_data_type.contains("boolean"))
199            }
200            risingwave_common::types::DataType::Int16 => {
201                Ok(starrocks_data_type.contains("smallint"))
202            }
203            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
204            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
205            risingwave_common::types::DataType::Float32 => {
206                Ok(starrocks_data_type.contains("float"))
207            }
208            risingwave_common::types::DataType::Float64 => {
209                Ok(starrocks_data_type.contains("double"))
210            }
211            risingwave_common::types::DataType::Decimal => {
212                Ok(starrocks_data_type.contains("decimal"))
213            }
214            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
215            risingwave_common::types::DataType::Varchar => {
216                Ok(starrocks_data_type.contains("varchar"))
217            }
218            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
219                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
220            )),
221            risingwave_common::types::DataType::Timestamp => {
222                Ok(starrocks_data_type.contains("datetime"))
223            }
224            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
225                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
226            )),
227            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
228                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
229            )),
230            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
231                "STRUCT is not supported for Starrocks sink.".to_owned(),
232            )),
233            risingwave_common::types::DataType::List(list) => {
234                // For compatibility with older versions starrocks
235                if starrocks_data_type.contains("unknown") {
236                    return Ok(true);
237                }
238                let check_result = Self::check_and_correct_column_type(list.as_ref(), starrocks_data_type)?;
239                Ok(check_result && starrocks_data_type.contains("array"))
240            }
241            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
242                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
243            )),
244            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
245            risingwave_common::types::DataType::Serial => {
246                Ok(starrocks_data_type.contains("bigint"))
247            }
248            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
249                "INT256 is not supported for Starrocks sink.".to_owned(),
250            )),
251            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
252                "MAP is not supported for Starrocks sink.".to_owned(),
253            )),
254        }
255    }
256}
257
258impl Sink for StarrocksSink {
259    type Coordinator = StarrocksSinkCommitter;
260    type LogSinker = CoordinatedLogSinker<StarrocksSinkWriter>;
261
262    const SINK_NAME: &'static str = STARROCKS_SINK;
263
264    async fn validate(&self) -> Result<()> {
265        if !self.is_append_only && self.pk_indices.is_empty() {
266            return Err(SinkError::Config(anyhow!(
267                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
268            )));
269        }
270        // check reachability
271        let mut client = StarrocksSchemaClient::new(
272            self.config.common.host.clone(),
273            self.config.common.mysql_port.clone(),
274            self.config.common.table.clone(),
275            self.config.common.database.clone(),
276            self.config.common.user.clone(),
277            self.config.common.password.clone(),
278        )
279        .await?;
280        let (read_model, pks) = client.get_pk_from_starrocks().await?;
281
282        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
283            return Err(SinkError::Config(anyhow!(
284                "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
285            )));
286        }
287
288        for (index, filed) in self.schema.fields().iter().enumerate() {
289            if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
290                return Err(SinkError::Starrocks(format!(
291                    "Can't find pk {:?} in starrocks",
292                    filed.name
293                )));
294            }
295        }
296
297        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
298
299        self.check_column_name_and_type(starrocks_columns_desc)?;
300        Ok(())
301    }
302
303    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
304        let commit_checkpoint_interval =
305            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
306                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
307            );
308
309        let inner = StarrocksSinkWriter::new(
310            self.config.clone(),
311            self.schema.clone(),
312            self.pk_indices.clone(),
313            self.is_append_only,
314            writer_param.executor_id,
315        )?;
316
317        let writer = CoordinatedLogSinker::new(
318            &writer_param,
319            self.param.clone(),
320            inner,
321            commit_checkpoint_interval,
322        )
323        .await?;
324        Ok(writer)
325    }
326
327    fn is_coordinated_sink(&self) -> bool {
328        true
329    }
330
331    async fn new_coordinator(&self, _db: DatabaseConnection) -> Result<Self::Coordinator> {
332        let header = HeaderBuilder::new()
333            .add_common_header()
334            .set_user_password(
335                self.config.common.user.clone(),
336                self.config.common.password.clone(),
337            )
338            .set_db(self.config.common.database.clone())
339            .set_table(self.config.common.table.clone())
340            .build();
341
342        let txn_request_builder = StarrocksTxnRequestBuilder::new(
343            format!(
344                "http://{}:{}",
345                self.config.common.host, self.config.common.http_port
346            ),
347            header,
348            self.config.stream_load_http_timeout_ms,
349        )?;
350        Ok(StarrocksSinkCommitter {
351            client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
352        })
353    }
354}
355
/// Per-executor writer that stream-loads JSON rows into `StarRocks` inside a transaction.
///
/// The writer begins a transaction and opens a load request lazily in `write_batch`. It
/// finishes the load request on every barrier and prepares the transaction on checkpoint
/// barriers. Committing is left to the coordinator.
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    // Open load request of the current transaction: created in `write_batch`, finished and
    // cleared in `barrier`.
    client: Option<StarrocksClient>,
    // Shared through `Arc` so that `Drop` can move a handle into a spawned rollback task.
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    // Used in transaction labels (see `new_txn_label`).
    executor_id: u64,
    // Label of the transaction begun in `write_batch` that has not been handed off yet;
    // cleared on a checkpoint barrier or by `abort`, and rolled back on `drop` if still set.
    curr_txn_label: Option<String>,
}
369
370impl TryFrom<SinkParam> for StarrocksSink {
371    type Error = SinkError;
372
373    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
374        let schema = param.schema();
375        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
376        StarrocksSink::new(param, config, schema)
377    }
378}
379
380impl StarrocksSinkWriter {
381    pub fn new(
382        config: StarrocksConfig,
383        schema: Schema,
384        pk_indices: Vec<usize>,
385        is_append_only: bool,
386        executor_id: u64,
387    ) -> Result<Self> {
388        let mut field_names = schema.names_str();
389        if !is_append_only {
390            field_names.push(STARROCKS_DELETE_SIGN);
391        };
392        // we should quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request due to
393        // a field name being a reserved word. For example, `order`, 'from`, etc.
394        let field_names = field_names
395            .into_iter()
396            .map(|name| format!("`{}`", name))
397            .collect::<Vec<String>>();
398        let field_names_str = field_names
399            .iter()
400            .map(|name| name.as_str())
401            .collect::<Vec<&str>>();
402
403        let header = HeaderBuilder::new()
404            .add_common_header()
405            .set_user_password(config.common.user.clone(), config.common.password.clone())
406            .add_json_format()
407            .set_partial_update(config.partial_update.clone())
408            .set_columns_name(field_names_str)
409            .set_db(config.common.database.clone())
410            .set_table(config.common.table.clone())
411            .build();
412
413        let txn_request_builder = StarrocksTxnRequestBuilder::new(
414            format!("http://{}:{}", config.common.host, config.common.http_port),
415            header,
416            config.stream_load_http_timeout_ms,
417        )?;
418
419        Ok(Self {
420            config,
421            schema: schema.clone(),
422            pk_indices,
423            is_append_only,
424            client: None,
425            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
426            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
427            executor_id,
428            curr_txn_label: None,
429        })
430    }
431
432    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
433        for (op, row) in chunk.rows() {
434            if op != Op::Insert {
435                continue;
436            }
437            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
438            self.client
439                .as_mut()
440                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
441                .write(row_json_string.into())
442                .await?;
443        }
444        Ok(())
445    }
446
447    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
448        for (op, row) in chunk.rows() {
449            match op {
450                Op::Insert => {
451                    let mut row_json_value = self.row_encoder.encode(row)?;
452                    row_json_value.insert(
453                        STARROCKS_DELETE_SIGN.to_owned(),
454                        Value::String("0".to_owned()),
455                    );
456                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
457                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
458                    })?;
459                    self.client
460                        .as_mut()
461                        .ok_or_else(|| {
462                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
463                        })?
464                        .write(row_json_string.into())
465                        .await?;
466                }
467                Op::Delete => {
468                    let mut row_json_value = self.row_encoder.encode(row)?;
469                    row_json_value.insert(
470                        STARROCKS_DELETE_SIGN.to_owned(),
471                        Value::String("1".to_owned()),
472                    );
473                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
474                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
475                    })?;
476                    self.client
477                        .as_mut()
478                        .ok_or_else(|| {
479                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
480                        })?
481                        .write(row_json_string.into())
482                        .await?;
483                }
484                Op::UpdateDelete => {}
485                Op::UpdateInsert => {
486                    let mut row_json_value = self.row_encoder.encode(row)?;
487                    row_json_value.insert(
488                        STARROCKS_DELETE_SIGN.to_owned(),
489                        Value::String("0".to_owned()),
490                    );
491                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
492                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
493                    })?;
494                    self.client
495                        .as_mut()
496                        .ok_or_else(|| {
497                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
498                        })?
499                        .write(row_json_string.into())
500                        .await?;
501                }
502            }
503        }
504        Ok(())
505    }
506
507    /// Generating a new transaction label, should be unique across all `SinkWriters` even under rewinding.
508    #[inline(always)]
509    fn new_txn_label(&self) -> String {
510        format!(
511            "rw-txn-{}-{}",
512            self.executor_id,
513            chrono::Utc::now().timestamp_micros()
514        )
515    }
516}
517
518impl Drop for StarrocksSinkWriter {
519    fn drop(&mut self) {
520        if let Some(txn_label) = self.curr_txn_label.take() {
521            let txn_client = self.txn_client.clone();
522            tokio::spawn(async move {
523                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
524                    tracing::error!(
525                        "starrocks rollback transaction error: {:?}, txn label: {}",
526                        e.as_report(),
527                        txn_label
528                    );
529                }
530            });
531        }
532    }
533}
534
535#[async_trait]
536impl SinkWriter for StarrocksSinkWriter {
537    type CommitMetadata = Option<SinkMetadata>;
538
539    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
540        Ok(())
541    }
542
543    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
544        // We check whether start a new transaction in `write_batch`. Therefore, if no data has been written
545        // within the `commit_checkpoint_interval` period, no meta requests will be made. Otherwise if we request
546        // `prepare` against an empty transaction, the `StarRocks` will report a `hasn't send any data yet` error.
547        if self.curr_txn_label.is_none() {
548            let txn_label = self.new_txn_label();
549            tracing::debug!(?txn_label, "begin transaction");
550            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
551            assert_eq!(
552                txn_label, txn_label_res,
553                "label responding from StarRocks: {} differ from generated one: {}",
554                txn_label, txn_label_res
555            );
556            self.curr_txn_label = Some(txn_label.clone());
557        }
558        if self.client.is_none() {
559            let txn_label = self.curr_txn_label.clone();
560            assert!(txn_label.is_some(), "transaction label is none during load");
561            self.client = Some(StarrocksClient::new(
562                self.txn_client.load(txn_label.unwrap()).await?,
563            ));
564        }
565        if self.is_append_only {
566            self.append_only(chunk).await
567        } else {
568            self.upsert(chunk).await
569        }
570    }
571
572    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
573        if self.client.is_some() {
574            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
575            // one or more load requests should be made within one commit_checkpoint_interval period.
576            // StarRocks will take care of merging those splits into a larger one during prepare transaction.
577            // Thus, only one version will be produced when the transaction is committed. See Stream Load
578            // transaction interface for more information.
579            let client = self
580                .client
581                .take()
582                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks inserter".to_owned()))?;
583            client.finish().await?;
584        }
585
586        if is_checkpoint {
587            if self.curr_txn_label.is_some() {
588                let txn_label = self.curr_txn_label.take().unwrap();
589                tracing::debug!(?txn_label, "prepare transaction");
590                let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
591                assert_eq!(
592                    txn_label, txn_label_res,
593                    "label responding from StarRocks differs from the current one"
594                );
595                Ok(Some(StarrocksWriteResult(Some(txn_label)).try_into()?))
596            } else {
597                // no data was written within previous epoch
598                Ok(Some(StarrocksWriteResult(None).try_into()?))
599            }
600        } else {
601            Ok(None)
602        }
603    }
604
605    async fn abort(&mut self) -> Result<()> {
606        if self.curr_txn_label.is_some() {
607            let txn_label = self.curr_txn_label.take().unwrap();
608            tracing::debug!(?txn_label, "rollback transaction");
609            self.txn_client.rollback(txn_label).await?;
610        }
611        Ok(())
612    }
613}
614
/// MySQL-protocol client for the `StarRocks` FE, used during validation to read the target
/// table's key model and column types from `information_schema`.
pub struct StarrocksSchemaClient {
    table: String,
    db: String,
    conn: mysql_async::Conn,
}
620
621impl StarrocksSchemaClient {
622    pub async fn new(
623        host: String,
624        port: String,
625        table: String,
626        db: String,
627        user: String,
628        password: String,
629    ) -> Result<Self> {
630        // username & password may contain special chars, so we need to do URL encoding on them.
631        // Otherwise, Opts::from_url may report a `Parse error`
632        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
633        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
634
635        let conn_uri = format!(
636            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
637            user,
638            password,
639            host,
640            port,
641            db,
642            STARROCK_MYSQL_PREFER_SOCKET,
643            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
644            STARROCK_MYSQL_WAIT_TIMEOUT
645        );
646        let pool = mysql_async::Pool::new(
647            Opts::from_url(&conn_uri)
648                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
649        );
650        let conn = pool
651            .get_conn()
652            .await
653            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
654
655        Ok(Self { table, db, conn })
656    }
657
658    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
659        let query = format!(
660            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
661            self.table, self.db
662        );
663        let mut query_map: HashMap<String, String> = HashMap::default();
664        self.conn
665            .query_map(query, |(column_name, column_type)| {
666                query_map.insert(column_name, column_type)
667            })
668            .await
669            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
670        Ok(query_map)
671    }
672
673    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
674        let query = format!(
675            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
676            self.table, self.db
677        );
678        let table_mode_pk: (String, String) = self
679            .conn
680            .query_map(
681                query,
682                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
683                    .as_str()
684                {
685                    // Get primary key of aggregate table from the sort_key field
686                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
687                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
688                    "AGG_KEYS" => (table_model, sort_key),
689                    _ => (table_model, primary_key),
690                },
691            )
692            .await
693            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
694            .first()
695            .ok_or_else(|| {
696                SinkError::Starrocks(format!(
697                    "Can't find schema with table {:?} and database {:?}",
698                    self.table, self.db
699                ))
700            })?
701            .clone();
702        Ok(table_mode_pk)
703    }
704}
705
/// JSON response body returned by the `StarRocks` load and transaction HTTP endpoints.
///
/// `StarrocksClient::finish` and `StarrocksTxnClient::check_response_and_extract_label`
/// deserialize it and treat the request as failed when `status` is not in
/// `STARROCKS_SUCCESS_STATUS`. Only `Status` and `Message` are required; every other field is
/// optional.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    // Transaction label; the transaction client requires it to be present.
    #[serde(rename = "Label")]
    pub label: Option<String>,
    // Compared against `STARROCKS_SUCCESS_STATUS` to decide success.
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    // Included in error messages when the status is not successful.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
745
/// A single open stream-load request into which JSON rows are written.
pub struct StarrocksClient {
    insert: InserterInner,
}
749impl StarrocksClient {
750    pub fn new(insert: InserterInner) -> Self {
751        Self { insert }
752    }
753
754    pub async fn write(&mut self, data: Bytes) -> Result<()> {
755        self.insert.write(data).await?;
756        Ok(())
757    }
758
759    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
760        let raw = self.insert.finish().await?;
761        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
762            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
763
764        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
765            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
766                "Insert error: {}, {}, {:?}",
767                res.status,
768                res.message,
769                res.error_url,
770            )));
771        };
772        Ok(res)
773    }
774}
775
776pub struct StarrocksTxnClient {
777    request_builder: StarrocksTxnRequestBuilder,
778}
779
780impl StarrocksTxnClient {
781    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
782        Self { request_builder }
783    }
784
785    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
786        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
787            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
788        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
789            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
790                "transaction error: {}, {}, {:?}",
791                res.status,
792                res.message,
793                res.error_url,
794            )));
795        }
796        res.label.ok_or_else(|| {
797            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
798        })
799    }
800
801    pub async fn begin(&self, label: String) -> Result<String> {
802        let res = self
803            .request_builder
804            .build_begin_request_sender(label)?
805            .send()
806            .await?;
807        self.check_response_and_extract_label(res)
808    }
809
810    pub async fn prepare(&self, label: String) -> Result<String> {
811        let res = self
812            .request_builder
813            .build_prepare_request_sender(label)?
814            .send()
815            .await?;
816        self.check_response_and_extract_label(res)
817    }
818
819    pub async fn commit(&self, label: String) -> Result<String> {
820        let res = self
821            .request_builder
822            .build_commit_request_sender(label)?
823            .send()
824            .await?;
825        self.check_response_and_extract_label(res)
826    }
827
828    pub async fn rollback(&self, label: String) -> Result<String> {
829        let res = self
830            .request_builder
831            .build_rollback_request_sender(label)?
832            .send()
833            .await?;
834        self.check_response_and_extract_label(res)
835    }
836
837    pub async fn load(&self, label: String) -> Result<InserterInner> {
838        self.request_builder.build_txn_inserter(label).await
839    }
840}
841
842struct StarrocksWriteResult(Option<String>);
843
844impl TryFrom<StarrocksWriteResult> for SinkMetadata {
845    type Error = SinkError;
846
847    fn try_from(value: StarrocksWriteResult) -> std::result::Result<Self, Self::Error> {
848        match value.0 {
849            Some(label) => {
850                let metadata = label.into_bytes();
851                Ok(SinkMetadata {
852                    metadata: Some(Serialized(SerializedMetadata { metadata })),
853                })
854            }
855            None => Ok(SinkMetadata { metadata: None }),
856        }
857    }
858}
859
860impl TryFrom<SinkMetadata> for StarrocksWriteResult {
861    type Error = SinkError;
862
863    fn try_from(value: SinkMetadata) -> std::result::Result<Self, Self::Error> {
864        if let Some(Serialized(v)) = value.metadata {
865            Ok(StarrocksWriteResult(Some(
866                String::from_utf8(v.metadata)
867                    .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
868            )))
869        } else {
870            Ok(StarrocksWriteResult(None))
871        }
872    }
873}
874
875pub struct StarrocksSinkCommitter {
876    client: Arc<StarrocksTxnClient>,
877}
878
879#[async_trait::async_trait]
880impl SinkCommitCoordinator for StarrocksSinkCommitter {
881    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
882        tracing::info!("Starrocks commit coordinator inited.");
883        Ok(None)
884    }
885
886    async fn commit(&mut self, epoch: u64, metadata: Vec<SinkMetadata>) -> Result<()> {
887        let write_results = metadata
888            .into_iter()
889            .map(TryFrom::try_from)
890            .collect::<Result<Vec<StarrocksWriteResult>>>()?;
891
892        let txn_labels = write_results
893            .into_iter()
894            .filter_map(|v| v.0)
895            .collect::<Vec<String>>();
896
897        tracing::debug!(?epoch, ?txn_labels, "commit transaction");
898
899        if !txn_labels.is_empty() {
900            futures::future::try_join_all(
901                txn_labels
902                    .into_iter()
903                    .map(|txn_label| self.client.commit(txn_label)),
904            )
905            .await?;
906        }
907        Ok(())
908    }
909}