risingwave_connector/sink/starrocks.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use mysql_async::Opts;
use mysql_async::prelude::Queryable;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_pb::id::ExecutorId;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use url::form_urlencoded;
use with_options::WithOptions;

use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
use super::doris_starrocks_connector::{
    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
    StarrocksTxnRequestBuilder,
};
use super::encoder::{JsonEncoder, RowEncoder};
use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
    SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
use crate::sink::writer::SinkWriter;
use crate::sink::{Result, Sink, SinkWriterParam};

pub const STARROCKS_SINK: &str = "starrocks";
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;

pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    30 * 1000
}

const fn default_use_https() -> bool {
    false
}

#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located.
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use HTTPS to connect to the `StarRocks` server.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}

impl EnforceSecret for StarrocksCommon {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}

#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// The timeout in milliseconds for the Stream Load HTTP request. Defaults to 30 seconds.
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    /// Set this option to a positive integer n, and RisingWave will try to commit data
    /// to `StarRocks` at every n checkpoints by leveraging the
    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/).
    /// When doing so, the `sink_decouple` option should be enabled as well.
    /// Defaults to 10 when not specified; a value of 0 is rejected.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable partial update.
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    pub r#type: String, // accept "append-only" or "upsert"
}

impl EnforceSecret for StarrocksConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}

fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}

impl StarrocksConfig {
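    /// A minimal usage sketch (not compiled): the property keys follow the serde renames
    /// declared above, and the concrete values here are hypothetical.
    ///
    /// ```ignore
    /// use std::collections::BTreeMap;
    ///
    /// let props = BTreeMap::from([
    ///     ("starrocks.host".to_owned(), "127.0.0.1".to_owned()),
    ///     ("starrocks.mysqlport".to_owned(), "9030".to_owned()),
    ///     ("starrocks.httpport".to_owned(), "8030".to_owned()),
    ///     ("starrocks.user".to_owned(), "root".to_owned()),
    ///     ("starrocks.password".to_owned(), "".to_owned()),
    ///     ("starrocks.database".to_owned(), "demo_db".to_owned()),
    ///     ("starrocks.table".to_owned(), "demo_table".to_owned()),
    ///     ("type".to_owned(), "upsert".to_owned()),
    /// ]);
    /// let config = StarrocksConfig::from_btreemap(props)?;
    /// assert_eq!(
    ///     config.commit_checkpoint_interval,
    ///     DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
    /// );
    /// ```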
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        if config.commit_checkpoint_interval == 0 {
            return Err(SinkError::Config(anyhow!(
                "`commit_checkpoint_interval` must be greater than 0"
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for StarrocksSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            StarrocksConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl StarrocksSink {
    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
        let pk_indices = param.downstream_pk_or_empty();
        let is_append_only = param.sink_type.is_append_only();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl StarrocksSink {
    fn check_column_name_and_type(
        &self,
        starrocks_columns_desc: HashMap<String, String>,
    ) -> Result<()> {
        let rw_fields_name = self.schema.fields();
        if rw_fields_name.len() > starrocks_columns_desc.len() {
            return Err(SinkError::Starrocks(
                "The sink's columns must be a subset of (or equal to) the StarRocks table's columns."
                    .to_owned(),
            ));
        }

        for i in rw_fields_name {
            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
                SinkError::Starrocks(format!("Column {:?} not found in StarRocks", i.name))
            })?;
            if !Self::check_and_correct_column_type(&i.data_type, value)? {
                return Err(SinkError::Starrocks(format!(
                    "Column type mismatch for column {:?}: StarRocks type is {:?}, RisingWave type is {:?}",
                    i.name, value, i.data_type
                )));
            }
        }
        Ok(())
    }

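    /// Check whether a RisingWave column type is compatible with the StarRocks column type
    /// string reported by `information_schema.columns`.
    ///
    /// A minimal sketch (not compiled); the `"bigint(20)"` type string is an assumed example:
    ///
    /// ```ignore
    /// assert!(StarrocksSink::check_and_correct_column_type(
    ///     &DataType::Int64,
    ///     &"bigint(20)".to_owned(),
    /// )?);
    /// ```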
    fn check_and_correct_column_type(
        rw_data_type: &DataType,
        starrocks_data_type: &String,
    ) -> Result<bool> {
        match rw_data_type {
            DataType::Boolean => Ok(starrocks_data_type.contains("tinyint")
                || starrocks_data_type.contains("boolean")),
            DataType::Int16 => Ok(starrocks_data_type.contains("smallint")),
            DataType::Int32 => Ok(starrocks_data_type.contains("int")),
            DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
            DataType::Float32 => Ok(starrocks_data_type.contains("float")),
            DataType::Float64 => Ok(starrocks_data_type.contains("double")),
            DataType::Decimal => Ok(starrocks_data_type.contains("decimal")),
            DataType::Date => Ok(starrocks_data_type.contains("date")),
            DataType::Varchar => Ok(starrocks_data_type.contains("varchar")),
            DataType::Time => Err(SinkError::Starrocks(
                "TIME is not supported for StarRocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            DataType::Timestamp => Ok(starrocks_data_type.contains("datetime")),
            DataType::Timestamptz => Err(SinkError::Starrocks(
                "TIMESTAMP WITH TIMEZONE is not supported for StarRocks sink as StarRocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
            )),
            DataType::Interval => Err(SinkError::Starrocks(
                "INTERVAL is not supported for StarRocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            DataType::Struct(_) => Err(SinkError::Starrocks(
                "STRUCT is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::List(list) => {
                // For compatibility with older versions of StarRocks.
                if starrocks_data_type.contains("unknown") {
                    return Ok(true);
                }
                let check_result =
                    Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
                Ok(check_result && starrocks_data_type.contains("array"))
            }
            DataType::Bytea => Err(SinkError::Starrocks(
                "BYTEA is not supported for StarRocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
            DataType::Serial => Ok(starrocks_data_type.contains("bigint")),
            DataType::Int256 => Err(SinkError::Starrocks(
                "INT256 is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::Map(_) => Err(SinkError::Starrocks(
                "MAP is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::Vector(_) => Err(SinkError::Starrocks(
                "VECTOR is not supported for StarRocks sink.".to_owned(),
            )),
        }
    }
}

impl Sink for StarrocksSink {
    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;

    const SINK_NAME: &'static str = STARROCKS_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
            )));
        }
        // check reachability
        let mut client = StarrocksSchemaClient::new(
            self.config.common.host.clone(),
            self.config.common.mysql_port.clone(),
            self.config.common.table.clone(),
            self.config.common.database.clone(),
            self.config.common.user.clone(),
            self.config.common.password.clone(),
        )
        .await?;
        let (read_model, pks) = client.get_pk_from_starrocks().await?;

        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
            return Err(SinkError::Config(anyhow!(
                "To create an upsert sink, the StarRocks table must be a Primary Key table (table model `PRIMARY_KEYS`)"
            )));
        }

        for (index, field) in self.schema.fields().iter().enumerate() {
            if self.pk_indices.contains(&index) && !pks.contains(&field.name) {
                return Err(SinkError::Starrocks(format!(
                    "Can't find primary key {:?} in StarRocks",
                    field.name
                )));
            }
        }

        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;

        self.check_column_name_and_type(starrocks_columns_desc)?;
        Ok(())
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        StarrocksConfig::from_btreemap(config.clone())?;
        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        let writer = StarrocksSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
            writer_param.executor_id,
        )?;

        let metrics = SinkWriterMetrics::new(&writer_param);

        Ok(DecoupleCheckpointLogSinkerOf::new(
            writer,
            metrics,
            commit_checkpoint_interval,
        ))
    }
}

pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: Option<StarrocksClient>,
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    executor_id: ExecutorId,
    curr_txn_label: Option<String>,
}

impl TryFrom<SinkParam> for StarrocksSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
        StarrocksSink::new(param, config, schema)
    }
}

impl StarrocksSinkWriter {
    pub fn new(
        config: StarrocksConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
        executor_id: ExecutorId,
    ) -> Result<Self> {
        let mut field_names = schema.names_str();
        if !is_append_only {
            field_names.push(STARROCKS_DELETE_SIGN);
        };
        // We quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request
        // when a field name is a reserved word, e.g. `order`, `from`, etc.
        let field_names = field_names
            .into_iter()
            .map(|name| format!("`{}`", name))
            .collect::<Vec<String>>();
        let field_names_str = field_names
            .iter()
            .map(|name| name.as_str())
            .collect::<Vec<&str>>();

        let header = HeaderBuilder::new()
            .add_common_header()
            .set_user_password(config.common.user.clone(), config.common.password.clone())
            .add_json_format()
            .set_partial_update(config.partial_update.clone())
            .set_columns_name(field_names_str)
            .set_db(config.common.database.clone())
            .set_table(config.common.table.clone())
            .build();

        let url = if config.common.use_https {
            format!("https://{}:{}", config.common.host, config.common.http_port)
        } else {
            format!("http://{}:{}", config.common.host, config.common.http_port)
        };
        let txn_request_builder =
            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;

        Ok(Self {
            config,
            schema: schema.clone(),
            pk_indices,
            is_append_only,
            client: None,
            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
            executor_id,
            curr_txn_label: None,
        })
    }

    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                continue;
            }
            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
            self.client
                .as_mut()
                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
                .write(row_json_string.into())
                .await?;
        }
        Ok(())
    }

    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("Json serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::Delete => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("1".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("Json serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::UpdateDelete => {}
                Op::UpdateInsert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("Json serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
            }
        }
        Ok(())
    }

    /// Generate a new transaction label; it should be unique across all `SinkWriter`s, even under rewinding.
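    /// For illustration, a label has the shape `rw-txn-{executor_id}-{timestamp_micros}`,
    /// e.g. `rw-txn-42-1700000000000000` (hypothetical values).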
    #[inline(always)]
    fn new_txn_label(&self) -> String {
        format!(
            "rw-txn-{}-{}",
            self.executor_id,
            chrono::Utc::now().timestamp_micros()
        )
    }
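    /// Prepare and then commit the given Stream Load transaction, verifying at each step that
    /// the label echoed back by StarRocks matches the one we sent.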
    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
        tracing::debug!(?txn_label, "prepare transaction");
        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "prepare transaction for label {} returned a different label {}",
                txn_label, txn_label_res
            )));
        }
        tracing::debug!(?txn_label, "commit transaction");
        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "commit transaction for label {} returned a different label {}",
                txn_label, txn_label_res
            )));
        }
        Ok(())
    }
}

impl Drop for StarrocksSinkWriter {
    fn drop(&mut self) {
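        // `Drop` cannot be async, so the rollback is spawned as a detached, best-effort task;
        // failures are only logged.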
        if let Some(txn_label) = self.curr_txn_label.take() {
            let txn_client = self.txn_client.clone();
            tokio::spawn(async move {
                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
                    tracing::error!(
                        "starrocks rollback transaction error: {:?}, txn label: {}",
                        e.as_report(),
                        txn_label
                    );
                }
            });
        }
    }
}

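// Transaction lifecycle driven by this writer: a transaction is begun lazily on the first
// `write_batch` after the previous commit; each barrier finishes the in-flight
// `/api/transaction/load` request; a checkpoint barrier prepares and commits the transaction,
// rolling back on failure. `abort` and `Drop` also roll back any open transaction.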
#[async_trait]
impl SinkWriter for StarrocksSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // We check whether to start a new transaction in `write_batch`. Therefore, if no data has been
        // written within the `commit_checkpoint_interval` period, no meta requests will be made.
        // Otherwise, if we issued `prepare` against an empty transaction, `StarRocks` would report a
        // `hasn't send any data yet` error.
        if self.curr_txn_label.is_none() {
            let txn_label = self.new_txn_label();
            tracing::debug!(?txn_label, "begin transaction");
            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
            if txn_label != txn_label_res {
                return Err(SinkError::Starrocks(format!(
                    "begin transaction for label {} returned a different label {}",
                    txn_label, txn_label_res
                )));
            }
            self.curr_txn_label = Some(txn_label.clone());
        }
        if self.client.is_none() {
            let txn_label = self.curr_txn_label.clone();
            self.client = Some(StarrocksClient::new(
                self.txn_client.load(txn_label.unwrap()).await?,
            ));
        }
        if self.is_append_only {
            self.append_only(chunk).await
        } else {
            self.upsert(chunk).await
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        if let Some(client) = self.client.take() {
            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
            // one or more load requests may be made within one `commit_checkpoint_interval` period.
            // StarRocks takes care of merging those splits into a larger one during the prepare phase,
            // so only one version is produced when the transaction is committed. See the Stream Load
            // transaction interface documentation for more information.
            client.finish().await?;
        }

        if is_checkpoint
            && let Some(txn_label) = self.curr_txn_label.take()
            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
        {
            match self.txn_client.rollback(txn_label.clone()).await {
                Ok(_) => tracing::warn!(
                    ?txn_label,
                    "transaction is successfully rolled back due to commit failure"
                ),
                Err(err) => {
                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
                }
            }

            return Err(err);
        }
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        if let Some(txn_label) = self.curr_txn_label.take() {
            tracing::debug!(?txn_label, "rollback transaction");
            self.txn_client.rollback(txn_label).await?;
        }
        Ok(())
    }
}

pub struct StarrocksSchemaClient {
    table: String,
    db: String,
    conn: mysql_async::Conn,
}

impl StarrocksSchemaClient {
    pub async fn new(
        host: String,
        port: String,
        table: String,
        db: String,
        user: String,
        password: String,
    ) -> Result<Self> {
        // The username & password may contain special characters, so we need to URL-encode them.
        // Otherwise, `Opts::from_url` may report a `Parse error`.
        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();

        let conn_uri = format!(
            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
            user,
            password,
            host,
            port,
            db,
            STARROCK_MYSQL_PREFER_SOCKET,
            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
            STARROCK_MYSQL_WAIT_TIMEOUT
        );
        let pool = mysql_async::Pool::new(
            Opts::from_url(&conn_uri)
                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
        );
        let conn = pool
            .get_conn()
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        Ok(Self { table, db, conn })
    }

    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
        let query = format!(
            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let mut query_map: HashMap<String, String> = HashMap::default();
        self.conn
            .query_map(query, |(column_name, column_type)| {
                query_map.insert(column_name, column_type)
            })
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        Ok(query_map)
    }

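    /// Fetch the table model and key columns from `information_schema.tables_config`.
    /// Returns `(table_model, keys)`, where `keys` is the key-column list string as reported by
    /// StarRocks (the sort key is used instead for aggregate tables).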
    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
        let query = format!(
            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let table_mode_pk: (String, String) = self
            .conn
            .query_map(
                query,
                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
                    .as_str()
                {
                    // Get the primary key of an aggregate table from the sort_key field.
                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
                    "AGG_KEYS" => (table_model, sort_key),
                    _ => (table_model, primary_key),
                },
            )
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
            .first()
            .ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Can't find schema with table {:?} and database {:?}",
                    self.table, self.db
                ))
            })?
            .clone();
        Ok(table_mode_pk)
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}

pub struct StarrocksClient {
    insert: InserterInner,
}

impl StarrocksClient {
    pub fn new(insert: InserterInner) -> Self {
        Self { insert }
    }

    pub async fn write(&mut self, data: Bytes) -> Result<()> {
        self.insert.write(data).await?;
        Ok(())
    }

    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
        let raw = self.insert.finish().await?;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "Insert error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        };
        Ok(res)
    }
}

pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}

impl StarrocksTxnClient {
    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
        Self { request_builder }
    }

    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "transaction error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        }
        res.label.ok_or_else(|| {
            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
        })
    }

    pub async fn begin(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_begin_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn prepare(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_prepare_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn commit(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_commit_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn rollback(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_rollback_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn load(&self, label: String) -> Result<InserterInner> {
        self.request_builder.build_txn_inserter(label).await
    }
}
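
#[cfg(test)]
mod tests {
    // A minimal sketch of config-validation tests, assuming only what `from_btreemap` itself
    // enforces; all property values below are hypothetical.
    use super::*;

    fn base_props() -> BTreeMap<String, String> {
        BTreeMap::from(
            [
                ("starrocks.host", "127.0.0.1"),
                ("starrocks.mysqlport", "9030"),
                ("starrocks.httpport", "8030"),
                ("starrocks.user", "root"),
                ("starrocks.password", ""),
                ("starrocks.database", "demo_db"),
                ("starrocks.table", "demo_table"),
            ]
            .map(|(k, v)| (k.to_owned(), v.to_owned())),
        )
    }

    #[test]
    fn test_from_btreemap_accepts_append_only_with_defaults() {
        let mut props = base_props();
        props.insert("type".to_owned(), SINK_TYPE_APPEND_ONLY.to_owned());
        let config = StarrocksConfig::from_btreemap(props).unwrap();
        // The defaults declared above should apply when the options are absent.
        assert_eq!(config.stream_load_http_timeout_ms, 30 * 1000);
        assert_eq!(
            config.commit_checkpoint_interval,
            DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
        );
    }

    #[test]
    fn test_from_btreemap_rejects_bad_type_and_zero_interval() {
        // An unsupported sink type must be rejected.
        let mut props = base_props();
        props.insert("type".to_owned(), "debezium".to_owned());
        assert!(StarrocksConfig::from_btreemap(props).is_err());

        // A zero commit interval must be rejected.
        let mut props = base_props();
        props.insert("type".to_owned(), SINK_TYPE_UPSERT.to_owned());
        props.insert("commit_checkpoint_interval".to_owned(), "0".to_owned());
        assert!(StarrocksConfig::from_btreemap(props).is_err());
    }
}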