risingwave_connector/sink/
starrocks.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use mysql_async::Opts;
23use mysql_async::prelude::Queryable;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::types::DataType;
27use serde::Deserialize;
28use serde_derive::Serialize;
29use serde_json::Value;
30use serde_with::{DisplayFromStr, serde_as};
31use thiserror_ext::AsReport;
32use url::form_urlencoded;
33use with_options::WithOptions;
34
35use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
36use super::doris_starrocks_connector::{
37    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
38    StarrocksTxnRequestBuilder,
39};
40use super::encoder::{JsonEncoder, RowEncoder};
41use super::{
42    DummySinkCommitCoordinator, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
43    SinkError, SinkParam, SinkWriterMetrics,
44};
45use crate::enforce_secret::EnforceSecret;
46use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
47use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};
48
/// Connector name of the `StarRocks` sink (used as `Sink::SINK_NAME`).
pub const STARROCKS_SINK: &str = "starrocks";
/// Value of the `prefer_socket` parameter in the MySQL connection URL.
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
/// Value of the `max_allowed_packet` parameter in the MySQL connection URL.
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1_024;
/// Value of the `wait_timeout` parameter in the MySQL connection URL.
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28_800;
53
/// Serde default for `starrocks.stream_load.http.timeout.ms`: 30 seconds, in milliseconds.
const fn _default_stream_load_http_timeout_ms() -> u64 {
    30_000
}
57
58#[derive(Deserialize, Debug, Clone, WithOptions)]
59pub struct StarrocksCommon {
60    /// The `StarRocks` host address.
61    #[serde(rename = "starrocks.host")]
62    pub host: String,
63    /// The port to the MySQL server of `StarRocks` FE.
64    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
65    pub mysql_port: String,
66    /// The port to the HTTP server of `StarRocks` FE.
67    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
68    pub http_port: String,
69    /// The user name used to access the `StarRocks` database.
70    #[serde(rename = "starrocks.user")]
71    pub user: String,
72    /// The password associated with the user.
73    #[serde(rename = "starrocks.password")]
74    pub password: String,
75    /// The `StarRocks` database where the target table is located
76    #[serde(rename = "starrocks.database")]
77    pub database: String,
78    /// The `StarRocks` table you want to sink data to.
79    #[serde(rename = "starrocks.table")]
80    pub table: String,
81}
82
83impl EnforceSecret for StarrocksCommon {
84    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
85        "starrocks.password", "starrocks.user"
86    };
87}
88
89#[serde_as]
90#[derive(Clone, Debug, Deserialize, WithOptions)]
91pub struct StarrocksConfig {
92    #[serde(flatten)]
93    pub common: StarrocksCommon,
94
95    /// The timeout in milliseconds for stream load http request, defaults to 10 seconds.
96    #[serde(
97        rename = "starrocks.stream_load.http.timeout.ms",
98        default = "_default_stream_load_http_timeout_ms"
99    )]
100    #[serde_as(as = "DisplayFromStr")]
101    pub stream_load_http_timeout_ms: u64,
102
103    /// Set this option to a positive integer n, RisingWave will try to commit data
104    /// to Starrocks at every n checkpoints by leveraging the
105    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/),
106    /// also, in this time, the `sink_decouple` option should be enabled as well.
107    /// Defaults to 10 if commit_checkpoint_interval <= 0
108    #[serde(default = "default_commit_checkpoint_interval")]
109    #[serde_as(as = "DisplayFromStr")]
110    pub commit_checkpoint_interval: u64,
111
112    /// Enable partial update
113    #[serde(rename = "starrocks.partial_update")]
114    pub partial_update: Option<String>,
115
116    pub r#type: String, // accept "append-only" or "upsert"
117}
118
119impl EnforceSecret for StarrocksConfig {
120    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
121        StarrocksCommon::enforce_one(prop)
122    }
123}
124
125fn default_commit_checkpoint_interval() -> u64 {
126    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
127}
128
129impl StarrocksConfig {
130    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
131        let config =
132            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
133                .map_err(|e| SinkError::Config(anyhow!(e)))?;
134        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
135            return Err(SinkError::Config(anyhow!(
136                "`{}` must be {}, or {}",
137                SINK_TYPE_OPTION,
138                SINK_TYPE_APPEND_ONLY,
139                SINK_TYPE_UPSERT
140            )));
141        }
142        if config.commit_checkpoint_interval == 0 {
143            return Err(SinkError::Config(anyhow!(
144                "`commit_checkpoint_interval` must be greater than 0"
145            )));
146        }
147        Ok(config)
148    }
149}
150
151#[derive(Debug)]
152pub struct StarrocksSink {
153    pub config: StarrocksConfig,
154    schema: Schema,
155    pk_indices: Vec<usize>,
156    is_append_only: bool,
157}
158
159impl EnforceSecret for StarrocksSink {
160    fn enforce_secret<'a>(
161        prop_iter: impl Iterator<Item = &'a str>,
162    ) -> crate::error::ConnectorResult<()> {
163        for prop in prop_iter {
164            StarrocksConfig::enforce_one(prop)?;
165        }
166        Ok(())
167    }
168}
169
170impl StarrocksSink {
171    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
172        let pk_indices = param.downstream_pk.clone();
173        let is_append_only = param.sink_type.is_append_only();
174        Ok(Self {
175            config,
176            schema,
177            pk_indices,
178            is_append_only,
179        })
180    }
181}
182
183impl StarrocksSink {
184    fn check_column_name_and_type(
185        &self,
186        starrocks_columns_desc: HashMap<String, String>,
187    ) -> Result<()> {
188        let rw_fields_name = self.schema.fields();
189        if rw_fields_name.len() > starrocks_columns_desc.len() {
190            return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
191        }
192
193        for i in rw_fields_name {
194            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
195                SinkError::Starrocks(format!(
196                    "Column name don't find in starrocks, risingwave is {:?} ",
197                    i.name
198                ))
199            })?;
200            if !Self::check_and_correct_column_type(&i.data_type, value)? {
201                return Err(SinkError::Starrocks(format!(
202                    "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
203                    i.name, value, i.data_type
204                )));
205            }
206        }
207        Ok(())
208    }
209
210    fn check_and_correct_column_type(
211        rw_data_type: &DataType,
212        starrocks_data_type: &String,
213    ) -> Result<bool> {
214        match rw_data_type {
215            risingwave_common::types::DataType::Boolean => {
216                Ok(starrocks_data_type.contains("tinyint") | starrocks_data_type.contains("boolean"))
217            }
218            risingwave_common::types::DataType::Int16 => {
219                Ok(starrocks_data_type.contains("smallint"))
220            }
221            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
222            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
223            risingwave_common::types::DataType::Float32 => {
224                Ok(starrocks_data_type.contains("float"))
225            }
226            risingwave_common::types::DataType::Float64 => {
227                Ok(starrocks_data_type.contains("double"))
228            }
229            risingwave_common::types::DataType::Decimal => {
230                Ok(starrocks_data_type.contains("decimal"))
231            }
232            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
233            risingwave_common::types::DataType::Varchar => {
234                Ok(starrocks_data_type.contains("varchar"))
235            }
236            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
237                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
238            )),
239            risingwave_common::types::DataType::Timestamp => {
240                Ok(starrocks_data_type.contains("datetime"))
241            }
242            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
243                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
244            )),
245            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
246                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
247            )),
248            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
249                "STRUCT is not supported for Starrocks sink.".to_owned(),
250            )),
251            risingwave_common::types::DataType::List(list) => {
252                // For compatibility with older versions starrocks
253                if starrocks_data_type.contains("unknown") {
254                    return Ok(true);
255                }
256                let check_result = Self::check_and_correct_column_type(list.as_ref(), starrocks_data_type)?;
257                Ok(check_result && starrocks_data_type.contains("array"))
258            }
259            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
260                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
261            )),
262            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
263            risingwave_common::types::DataType::Serial => {
264                Ok(starrocks_data_type.contains("bigint"))
265            }
266            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
267                "INT256 is not supported for Starrocks sink.".to_owned(),
268            )),
269            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
270                "MAP is not supported for Starrocks sink.".to_owned(),
271            )),
272            DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
273        }
274    }
275}
276
277impl Sink for StarrocksSink {
278    type Coordinator = DummySinkCommitCoordinator;
279    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
280
281    const SINK_ALTER_CONFIG_LIST: &'static [&'static str] = &["commit_checkpoint_interval"];
282    const SINK_NAME: &'static str = STARROCKS_SINK;
283
284    async fn validate(&self) -> Result<()> {
285        if !self.is_append_only && self.pk_indices.is_empty() {
286            return Err(SinkError::Config(anyhow!(
287                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
288            )));
289        }
290        // check reachability
291        let mut client = StarrocksSchemaClient::new(
292            self.config.common.host.clone(),
293            self.config.common.mysql_port.clone(),
294            self.config.common.table.clone(),
295            self.config.common.database.clone(),
296            self.config.common.user.clone(),
297            self.config.common.password.clone(),
298        )
299        .await?;
300        let (read_model, pks) = client.get_pk_from_starrocks().await?;
301
302        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
303            return Err(SinkError::Config(anyhow!(
304                "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
305            )));
306        }
307
308        for (index, filed) in self.schema.fields().iter().enumerate() {
309            if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
310                return Err(SinkError::Starrocks(format!(
311                    "Can't find pk {:?} in starrocks",
312                    filed.name
313                )));
314            }
315        }
316
317        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
318
319        self.check_column_name_and_type(starrocks_columns_desc)?;
320        Ok(())
321    }
322
323    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
324        StarrocksConfig::from_btreemap(config.clone())?;
325        Ok(())
326    }
327
328    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
329        let commit_checkpoint_interval =
330            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
331                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
332            );
333
334        let writer = StarrocksSinkWriter::new(
335            self.config.clone(),
336            self.schema.clone(),
337            self.pk_indices.clone(),
338            self.is_append_only,
339            writer_param.executor_id,
340        )?;
341
342        let metrics = SinkWriterMetrics::new(&writer_param);
343
344        Ok(DecoupleCheckpointLogSinkerOf::new(
345            writer,
346            metrics,
347            commit_checkpoint_interval,
348        ))
349    }
350}
351
352pub struct StarrocksSinkWriter {
353    pub config: StarrocksConfig,
354    #[expect(dead_code)]
355    schema: Schema,
356    #[expect(dead_code)]
357    pk_indices: Vec<usize>,
358    is_append_only: bool,
359    client: Option<StarrocksClient>,
360    txn_client: Arc<StarrocksTxnClient>,
361    row_encoder: JsonEncoder,
362    executor_id: u64,
363    curr_txn_label: Option<String>,
364}
365
366impl TryFrom<SinkParam> for StarrocksSink {
367    type Error = SinkError;
368
369    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
370        let schema = param.schema();
371        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
372        StarrocksSink::new(param, config, schema)
373    }
374}
375
376impl StarrocksSinkWriter {
377    pub fn new(
378        config: StarrocksConfig,
379        schema: Schema,
380        pk_indices: Vec<usize>,
381        is_append_only: bool,
382        executor_id: u64,
383    ) -> Result<Self> {
384        let mut field_names = schema.names_str();
385        if !is_append_only {
386            field_names.push(STARROCKS_DELETE_SIGN);
387        };
388        // we should quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request due to
389        // a field name being a reserved word. For example, `order`, 'from`, etc.
390        let field_names = field_names
391            .into_iter()
392            .map(|name| format!("`{}`", name))
393            .collect::<Vec<String>>();
394        let field_names_str = field_names
395            .iter()
396            .map(|name| name.as_str())
397            .collect::<Vec<&str>>();
398
399        let header = HeaderBuilder::new()
400            .add_common_header()
401            .set_user_password(config.common.user.clone(), config.common.password.clone())
402            .add_json_format()
403            .set_partial_update(config.partial_update.clone())
404            .set_columns_name(field_names_str)
405            .set_db(config.common.database.clone())
406            .set_table(config.common.table.clone())
407            .build();
408
409        let txn_request_builder = StarrocksTxnRequestBuilder::new(
410            format!("http://{}:{}", config.common.host, config.common.http_port),
411            header,
412            config.stream_load_http_timeout_ms,
413        )?;
414
415        Ok(Self {
416            config,
417            schema: schema.clone(),
418            pk_indices,
419            is_append_only,
420            client: None,
421            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
422            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
423            executor_id,
424            curr_txn_label: None,
425        })
426    }
427
428    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
429        for (op, row) in chunk.rows() {
430            if op != Op::Insert {
431                continue;
432            }
433            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
434            self.client
435                .as_mut()
436                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
437                .write(row_json_string.into())
438                .await?;
439        }
440        Ok(())
441    }
442
443    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
444        for (op, row) in chunk.rows() {
445            match op {
446                Op::Insert => {
447                    let mut row_json_value = self.row_encoder.encode(row)?;
448                    row_json_value.insert(
449                        STARROCKS_DELETE_SIGN.to_owned(),
450                        Value::String("0".to_owned()),
451                    );
452                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
453                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
454                    })?;
455                    self.client
456                        .as_mut()
457                        .ok_or_else(|| {
458                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
459                        })?
460                        .write(row_json_string.into())
461                        .await?;
462                }
463                Op::Delete => {
464                    let mut row_json_value = self.row_encoder.encode(row)?;
465                    row_json_value.insert(
466                        STARROCKS_DELETE_SIGN.to_owned(),
467                        Value::String("1".to_owned()),
468                    );
469                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
470                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
471                    })?;
472                    self.client
473                        .as_mut()
474                        .ok_or_else(|| {
475                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
476                        })?
477                        .write(row_json_string.into())
478                        .await?;
479                }
480                Op::UpdateDelete => {}
481                Op::UpdateInsert => {
482                    let mut row_json_value = self.row_encoder.encode(row)?;
483                    row_json_value.insert(
484                        STARROCKS_DELETE_SIGN.to_owned(),
485                        Value::String("0".to_owned()),
486                    );
487                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
488                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
489                    })?;
490                    self.client
491                        .as_mut()
492                        .ok_or_else(|| {
493                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
494                        })?
495                        .write(row_json_string.into())
496                        .await?;
497                }
498            }
499        }
500        Ok(())
501    }
502
503    /// Generating a new transaction label, should be unique across all `SinkWriters` even under rewinding.
504    #[inline(always)]
505    fn new_txn_label(&self) -> String {
506        format!(
507            "rw-txn-{}-{}",
508            self.executor_id,
509            chrono::Utc::now().timestamp_micros()
510        )
511    }
512
513    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
514        tracing::debug!(?txn_label, "prepare transaction");
515        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
516        if txn_label != txn_label_res {
517            return Err(SinkError::Starrocks(format!(
518                "label {} returned from prepare transaction {} differs from the current one",
519                txn_label, txn_label_res
520            )));
521        }
522        tracing::debug!(?txn_label, "commit transaction");
523        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
524        if txn_label != txn_label_res {
525            return Err(SinkError::Starrocks(format!(
526                "label {} returned from commit transaction {} differs from the current one",
527                txn_label, txn_label_res
528            )));
529        }
530        Ok(())
531    }
532}
533
534impl Drop for StarrocksSinkWriter {
535    fn drop(&mut self) {
536        if let Some(txn_label) = self.curr_txn_label.take() {
537            let txn_client = self.txn_client.clone();
538            tokio::spawn(async move {
539                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
540                    tracing::error!(
541                        "starrocks rollback transaction error: {:?}, txn label: {}",
542                        e.as_report(),
543                        txn_label
544                    );
545                }
546            });
547        }
548    }
549}
550
551#[async_trait]
552impl SinkWriter for StarrocksSinkWriter {
553    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
554        Ok(())
555    }
556
557    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
558        // We check whether start a new transaction in `write_batch`. Therefore, if no data has been written
559        // within the `commit_checkpoint_interval` period, no meta requests will be made. Otherwise if we request
560        // `prepare` against an empty transaction, the `StarRocks` will report a `hasn't send any data yet` error.
561        if self.curr_txn_label.is_none() {
562            let txn_label = self.new_txn_label();
563            tracing::debug!(?txn_label, "begin transaction");
564            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
565            if txn_label != txn_label_res {
566                return Err(SinkError::Starrocks(format!(
567                    "label {} returned from StarRocks {} differs from generated one",
568                    txn_label, txn_label_res
569                )));
570            }
571            self.curr_txn_label = Some(txn_label.clone());
572        }
573        if self.client.is_none() {
574            let txn_label = self.curr_txn_label.clone();
575            self.client = Some(StarrocksClient::new(
576                self.txn_client.load(txn_label.unwrap()).await?,
577            ));
578        }
579        if self.is_append_only {
580            self.append_only(chunk).await
581        } else {
582            self.upsert(chunk).await
583        }
584    }
585
586    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
587        if let Some(client) = self.client.take() {
588            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
589            // one or more load requests should be made within one commit_checkpoint_interval period.
590            // StarRocks will take care of merging those splits into a larger one during prepare transaction.
591            // Thus, only one version will be produced when the transaction is committed. See Stream Load
592            // transaction interface for more information.
593            client.finish().await?;
594        }
595
596        if is_checkpoint && let Some(txn_label) = self.curr_txn_label.take() {
597            if let Err(err) = self.prepare_and_commit(txn_label.clone()).await {
598                match self.txn_client.rollback(txn_label.clone()).await {
599                    Ok(_) => tracing::warn!(
600                        ?txn_label,
601                        "transaction is successfully rolled back due to commit failure"
602                    ),
603                    Err(err) => {
604                        tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
605                    }
606                }
607
608                return Err(err);
609            }
610        }
611        Ok(())
612    }
613
614    async fn abort(&mut self) -> Result<()> {
615        if let Some(txn_label) = self.curr_txn_label.take() {
616            tracing::debug!(?txn_label, "rollback transaction");
617            self.txn_client.rollback(txn_label).await?;
618        }
619        Ok(())
620    }
621}
622
623pub struct StarrocksSchemaClient {
624    table: String,
625    db: String,
626    conn: mysql_async::Conn,
627}
628
629impl StarrocksSchemaClient {
630    pub async fn new(
631        host: String,
632        port: String,
633        table: String,
634        db: String,
635        user: String,
636        password: String,
637    ) -> Result<Self> {
638        // username & password may contain special chars, so we need to do URL encoding on them.
639        // Otherwise, Opts::from_url may report a `Parse error`
640        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
641        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
642
643        let conn_uri = format!(
644            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
645            user,
646            password,
647            host,
648            port,
649            db,
650            STARROCK_MYSQL_PREFER_SOCKET,
651            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
652            STARROCK_MYSQL_WAIT_TIMEOUT
653        );
654        let pool = mysql_async::Pool::new(
655            Opts::from_url(&conn_uri)
656                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
657        );
658        let conn = pool
659            .get_conn()
660            .await
661            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
662
663        Ok(Self { table, db, conn })
664    }
665
666    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
667        let query = format!(
668            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
669            self.table, self.db
670        );
671        let mut query_map: HashMap<String, String> = HashMap::default();
672        self.conn
673            .query_map(query, |(column_name, column_type)| {
674                query_map.insert(column_name, column_type)
675            })
676            .await
677            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
678        Ok(query_map)
679    }
680
681    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
682        let query = format!(
683            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
684            self.table, self.db
685        );
686        let table_mode_pk: (String, String) = self
687            .conn
688            .query_map(
689                query,
690                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
691                    .as_str()
692                {
693                    // Get primary key of aggregate table from the sort_key field
694                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
695                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
696                    "AGG_KEYS" => (table_model, sort_key),
697                    _ => (table_model, primary_key),
698                },
699            )
700            .await
701            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
702            .first()
703            .ok_or_else(|| {
704                SinkError::Starrocks(format!(
705                    "Can't find schema with table {:?} and database {:?}",
706                    self.table, self.db
707                ))
708            })?
709            .clone();
710        Ok(table_mode_pk)
711    }
712}
713
714#[derive(Debug, Serialize, Deserialize)]
715pub struct StarrocksInsertResultResponse {
716    #[serde(rename = "TxnId")]
717    pub txn_id: Option<i64>,
718    #[serde(rename = "Seq")]
719    pub seq: Option<i64>,
720    #[serde(rename = "Label")]
721    pub label: Option<String>,
722    #[serde(rename = "Status")]
723    pub status: String,
724    #[serde(rename = "Message")]
725    pub message: String,
726    #[serde(rename = "NumberTotalRows")]
727    pub number_total_rows: Option<i64>,
728    #[serde(rename = "NumberLoadedRows")]
729    pub number_loaded_rows: Option<i64>,
730    #[serde(rename = "NumberFilteredRows")]
731    pub number_filtered_rows: Option<i32>,
732    #[serde(rename = "NumberUnselectedRows")]
733    pub number_unselected_rows: Option<i32>,
734    #[serde(rename = "LoadBytes")]
735    pub load_bytes: Option<i64>,
736    #[serde(rename = "LoadTimeMs")]
737    pub load_time_ms: Option<i32>,
738    #[serde(rename = "BeginTxnTimeMs")]
739    pub begin_txn_time_ms: Option<i32>,
740    #[serde(rename = "ReadDataTimeMs")]
741    pub read_data_time_ms: Option<i32>,
742    #[serde(rename = "WriteDataTimeMs")]
743    pub write_data_time_ms: Option<i32>,
744    #[serde(rename = "CommitAndPublishTimeMs")]
745    pub commit_and_publish_time_ms: Option<i32>,
746    #[serde(rename = "StreamLoadPlanTimeMs")]
747    pub stream_load_plan_time_ms: Option<i32>,
748    #[serde(rename = "ExistingJobStatus")]
749    pub existing_job_status: Option<String>,
750    #[serde(rename = "ErrorURL")]
751    pub error_url: Option<String>,
752}
753
754pub struct StarrocksClient {
755    insert: InserterInner,
756}
757impl StarrocksClient {
758    pub fn new(insert: InserterInner) -> Self {
759        Self { insert }
760    }
761
762    pub async fn write(&mut self, data: Bytes) -> Result<()> {
763        self.insert.write(data).await?;
764        Ok(())
765    }
766
767    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
768        let raw = self.insert.finish().await?;
769        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
770            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
771
772        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
773            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
774                "Insert error: {}, {}, {:?}",
775                res.status,
776                res.message,
777                res.error_url,
778            )));
779        };
780        Ok(res)
781    }
782}
783
784pub struct StarrocksTxnClient {
785    request_builder: StarrocksTxnRequestBuilder,
786}
787
788impl StarrocksTxnClient {
789    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
790        Self { request_builder }
791    }
792
793    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
794        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
795            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
796        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
797            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
798                "transaction error: {}, {}, {:?}",
799                res.status,
800                res.message,
801                res.error_url,
802            )));
803        }
804        res.label.ok_or_else(|| {
805            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
806        })
807    }
808
809    pub async fn begin(&self, label: String) -> Result<String> {
810        let res = self
811            .request_builder
812            .build_begin_request_sender(label)?
813            .send()
814            .await?;
815        self.check_response_and_extract_label(res)
816    }
817
818    pub async fn prepare(&self, label: String) -> Result<String> {
819        let res = self
820            .request_builder
821            .build_prepare_request_sender(label)?
822            .send()
823            .await?;
824        self.check_response_and_extract_label(res)
825    }
826
827    pub async fn commit(&self, label: String) -> Result<String> {
828        let res = self
829            .request_builder
830            .build_commit_request_sender(label)?
831            .send()
832            .await?;
833        self.check_response_and_extract_label(res)
834    }
835
836    pub async fn rollback(&self, label: String) -> Result<String> {
837        let res = self
838            .request_builder
839            .build_rollback_request_sender(label)?
840            .send()
841            .await?;
842        self.check_response_and_extract_label(res)
843    }
844
845    pub async fn load(&self, label: String) -> Result<InserterInner> {
846        self.request_builder.build_txn_inserter(label).await
847    }
848}