risingwave_connector/sink/
starrocks.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use mysql_async::Opts;
23use mysql_async::prelude::Queryable;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::types::DataType;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use serde_with::{DisplayFromStr, serde_as};
30use thiserror_ext::AsReport;
31use url::form_urlencoded;
32use with_options::WithOptions;
33
34use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
35use super::doris_starrocks_connector::{
36    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
37    StarrocksTxnRequestBuilder,
38};
39use super::encoder::{JsonEncoder, RowEncoder};
40use super::{
41    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
42    SinkWriterMetrics,
43};
44use crate::enforce_secret::EnforceSecret;
45use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
46use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};
47
/// Connector name of this sink; exposed as `Sink::SINK_NAME` for `StarrocksSink`.
pub const STARROCKS_SINK: &str = "starrocks";
// The three constants below are substituted into the query string of the MySQL
// connection URI built in `StarrocksSchemaClient::new`.
/// Value of the `prefer_socket` URI option.
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
/// Value of the `max_allowed_packet` URI option.
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
/// Value of the `wait_timeout` URI option.
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;
52
/// Serde default for `starrocks.stream_load.http.timeout.ms`: 30 seconds, in milliseconds.
const fn _default_stream_load_http_timeout_ms() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    const MILLIS_PER_SEC: u64 = 1000;
    TIMEOUT_SECS * MILLIS_PER_SEC
}
56
/// Serde default for `starrocks.use_https`: plain HTTP unless the option is set.
const fn default_use_https() -> bool {
    false
}
60
/// Connection-level options of the `StarRocks` sink: FE address and ports, credentials,
/// and the target database/table.
///
/// Deserialized from the sink's string properties (see `StarrocksConfig::from_btreemap`).
/// Ports are kept as strings and are interpolated verbatim into the MySQL URI and the
/// HTTP(S) stream load URL.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use https to connect to the `StarRocks` server.
    // Parsed from its string form ("true"/"false") via `DisplayFromStr`; absent means `false`.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}
92
/// Declares which `StarRocks` properties are subject to secret enforcement.
impl EnforceSecret for StarrocksCommon {
    // Both the credential pair's password and user name are listed.
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}
98
/// Full configuration of the `StarRocks` sink: the shared connection options plus
/// stream load and commit tuning.
// NOTE(review): if user-facing option docs are generated from the field comments below,
// regenerate them after touching this struct — TODO confirm.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// The timeout in milliseconds for stream load http request, defaults to 30 seconds.
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub stream_load_http_timeout_ms: u64,

    /// Set this option to a positive integer n, RisingWave will try to commit data
    /// to Starrocks at every n checkpoints by leveraging the
    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/),
    /// also, in this time, the `sink_decouple` option should be enabled as well.
    /// When the option is not set it defaults to
    /// `DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE`. The value must be greater
    /// than 0; `0` is rejected when the configuration is validated.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable partial update
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    pub r#type: String, // accept "append-only" or "upsert"
}
129
130impl EnforceSecret for StarrocksConfig {
131    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
132        StarrocksCommon::enforce_one(prop)
133    }
134}
135
/// Serde default for `commit_checkpoint_interval`: the shared default used by sinks
/// running with sink decoupling.
fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}
139
140impl StarrocksConfig {
141    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
142        let config =
143            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
144                .map_err(|e| SinkError::Config(anyhow!(e)))?;
145        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
146            return Err(SinkError::Config(anyhow!(
147                "`{}` must be {}, or {}",
148                SINK_TYPE_OPTION,
149                SINK_TYPE_APPEND_ONLY,
150                SINK_TYPE_UPSERT
151            )));
152        }
153        if config.commit_checkpoint_interval == 0 {
154            return Err(SinkError::Config(anyhow!(
155                "`commit_checkpoint_interval` must be greater than 0"
156            )));
157        }
158        Ok(config)
159    }
160}
161
/// The `StarRocks` sink: validated configuration plus the shape of the data to write.
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    // Schema of the rows RisingWave sends to the sink.
    schema: Schema,
    // Indices into `schema` of the downstream primary key columns (from `SinkParam`).
    pk_indices: Vec<usize>,
    // `true` for append-only sinks; `false` means upsert semantics.
    is_append_only: bool,
}
169
170impl EnforceSecret for StarrocksSink {
171    fn enforce_secret<'a>(
172        prop_iter: impl Iterator<Item = &'a str>,
173    ) -> crate::error::ConnectorResult<()> {
174        for prop in prop_iter {
175            StarrocksConfig::enforce_one(prop)?;
176        }
177        Ok(())
178    }
179}
180
181impl StarrocksSink {
182    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
183        let pk_indices = param.downstream_pk.clone();
184        let is_append_only = param.sink_type.is_append_only();
185        Ok(Self {
186            config,
187            schema,
188            pk_indices,
189            is_append_only,
190        })
191    }
192}
193
194impl StarrocksSink {
195    fn check_column_name_and_type(
196        &self,
197        starrocks_columns_desc: HashMap<String, String>,
198    ) -> Result<()> {
199        let rw_fields_name = self.schema.fields();
200        if rw_fields_name.len() > starrocks_columns_desc.len() {
201            return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
202        }
203
204        for i in rw_fields_name {
205            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
206                SinkError::Starrocks(format!(
207                    "Column name don't find in starrocks, risingwave is {:?} ",
208                    i.name
209                ))
210            })?;
211            if !Self::check_and_correct_column_type(&i.data_type, value)? {
212                return Err(SinkError::Starrocks(format!(
213                    "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
214                    i.name, value, i.data_type
215                )));
216            }
217        }
218        Ok(())
219    }
220
221    fn check_and_correct_column_type(
222        rw_data_type: &DataType,
223        starrocks_data_type: &String,
224    ) -> Result<bool> {
225        match rw_data_type {
226            risingwave_common::types::DataType::Boolean => {
227                Ok(starrocks_data_type.contains("tinyint") | starrocks_data_type.contains("boolean"))
228            }
229            risingwave_common::types::DataType::Int16 => {
230                Ok(starrocks_data_type.contains("smallint"))
231            }
232            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
233            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
234            risingwave_common::types::DataType::Float32 => {
235                Ok(starrocks_data_type.contains("float"))
236            }
237            risingwave_common::types::DataType::Float64 => {
238                Ok(starrocks_data_type.contains("double"))
239            }
240            risingwave_common::types::DataType::Decimal => {
241                Ok(starrocks_data_type.contains("decimal"))
242            }
243            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
244            risingwave_common::types::DataType::Varchar => {
245                Ok(starrocks_data_type.contains("varchar"))
246            }
247            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
248                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
249            )),
250            risingwave_common::types::DataType::Timestamp => {
251                Ok(starrocks_data_type.contains("datetime"))
252            }
253            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
254                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
255            )),
256            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
257                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
258            )),
259            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
260                "STRUCT is not supported for Starrocks sink.".to_owned(),
261            )),
262            risingwave_common::types::DataType::List(list) => {
263                // For compatibility with older versions starrocks
264                if starrocks_data_type.contains("unknown") {
265                    return Ok(true);
266                }
267                let check_result = Self::check_and_correct_column_type(list.as_ref(), starrocks_data_type)?;
268                Ok(check_result && starrocks_data_type.contains("array"))
269            }
270            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
271                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
272            )),
273            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
274            risingwave_common::types::DataType::Serial => {
275                Ok(starrocks_data_type.contains("bigint"))
276            }
277            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
278                "INT256 is not supported for Starrocks sink.".to_owned(),
279            )),
280            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
281                "MAP is not supported for Starrocks sink.".to_owned(),
282            )),
283            DataType::Vector(_) => Err(SinkError::Starrocks(
284                "VECTOR is not supported for Starrocks sink.".to_owned(),
285            )),
286        }
287    }
288}
289
290impl Sink for StarrocksSink {
291    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
292
293    const SINK_NAME: &'static str = STARROCKS_SINK;
294
295    async fn validate(&self) -> Result<()> {
296        if !self.is_append_only && self.pk_indices.is_empty() {
297            return Err(SinkError::Config(anyhow!(
298                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
299            )));
300        }
301        // check reachability
302        let mut client = StarrocksSchemaClient::new(
303            self.config.common.host.clone(),
304            self.config.common.mysql_port.clone(),
305            self.config.common.table.clone(),
306            self.config.common.database.clone(),
307            self.config.common.user.clone(),
308            self.config.common.password.clone(),
309        )
310        .await?;
311        let (read_model, pks) = client.get_pk_from_starrocks().await?;
312
313        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
314            return Err(SinkError::Config(anyhow!(
315                "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
316            )));
317        }
318
319        for (index, filed) in self.schema.fields().iter().enumerate() {
320            if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
321                return Err(SinkError::Starrocks(format!(
322                    "Can't find pk {:?} in starrocks",
323                    filed.name
324                )));
325            }
326        }
327
328        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
329
330        self.check_column_name_and_type(starrocks_columns_desc)?;
331        Ok(())
332    }
333
334    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
335        StarrocksConfig::from_btreemap(config.clone())?;
336        Ok(())
337    }
338
339    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
340        let commit_checkpoint_interval =
341            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
342                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
343            );
344
345        let writer = StarrocksSinkWriter::new(
346            self.config.clone(),
347            self.schema.clone(),
348            self.pk_indices.clone(),
349            self.is_append_only,
350            writer_param.executor_id,
351        )?;
352
353        let metrics = SinkWriterMetrics::new(&writer_param);
354
355        Ok(DecoupleCheckpointLogSinkerOf::new(
356            writer,
357            metrics,
358            commit_checkpoint_interval,
359        ))
360    }
361}
362
/// Writer that streams rows into `StarRocks` through the stream load transaction
/// interface.
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    // Selects between the append-only and the upsert write path in `write_batch`.
    is_append_only: bool,
    // The currently open load request, if any. Created lazily in `write_batch` and
    // finished (taken) on every barrier.
    client: Option<StarrocksClient>,
    // Shared so the `Drop` impl can hand a clone to the spawned rollback task.
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    // Part of every generated transaction label.
    executor_id: u64,
    // Label of the transaction that has been begun but not yet committed or rolled back.
    curr_txn_label: Option<String>,
}
376
377impl TryFrom<SinkParam> for StarrocksSink {
378    type Error = SinkError;
379
380    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
381        let schema = param.schema();
382        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
383        StarrocksSink::new(param, config, schema)
384    }
385}
386
387impl StarrocksSinkWriter {
388    pub fn new(
389        config: StarrocksConfig,
390        schema: Schema,
391        pk_indices: Vec<usize>,
392        is_append_only: bool,
393        executor_id: u64,
394    ) -> Result<Self> {
395        let mut field_names = schema.names_str();
396        if !is_append_only {
397            field_names.push(STARROCKS_DELETE_SIGN);
398        };
399        // we should quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request due to
400        // a field name being a reserved word. For example, `order`, 'from`, etc.
401        let field_names = field_names
402            .into_iter()
403            .map(|name| format!("`{}`", name))
404            .collect::<Vec<String>>();
405        let field_names_str = field_names
406            .iter()
407            .map(|name| name.as_str())
408            .collect::<Vec<&str>>();
409
410        let header = HeaderBuilder::new()
411            .add_common_header()
412            .set_user_password(config.common.user.clone(), config.common.password.clone())
413            .add_json_format()
414            .set_partial_update(config.partial_update.clone())
415            .set_columns_name(field_names_str)
416            .set_db(config.common.database.clone())
417            .set_table(config.common.table.clone())
418            .build();
419
420        let url = if config.common.use_https {
421            format!("https://{}:{}", config.common.host, config.common.http_port)
422        } else {
423            format!("http://{}:{}", config.common.host, config.common.http_port)
424        };
425        let txn_request_builder =
426            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
427
428        Ok(Self {
429            config,
430            schema: schema.clone(),
431            pk_indices,
432            is_append_only,
433            client: None,
434            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
435            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
436            executor_id,
437            curr_txn_label: None,
438        })
439    }
440
441    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
442        for (op, row) in chunk.rows() {
443            if op != Op::Insert {
444                continue;
445            }
446            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
447            self.client
448                .as_mut()
449                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
450                .write(row_json_string.into())
451                .await?;
452        }
453        Ok(())
454    }
455
456    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
457        for (op, row) in chunk.rows() {
458            match op {
459                Op::Insert => {
460                    let mut row_json_value = self.row_encoder.encode(row)?;
461                    row_json_value.insert(
462                        STARROCKS_DELETE_SIGN.to_owned(),
463                        Value::String("0".to_owned()),
464                    );
465                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
466                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
467                    })?;
468                    self.client
469                        .as_mut()
470                        .ok_or_else(|| {
471                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
472                        })?
473                        .write(row_json_string.into())
474                        .await?;
475                }
476                Op::Delete => {
477                    let mut row_json_value = self.row_encoder.encode(row)?;
478                    row_json_value.insert(
479                        STARROCKS_DELETE_SIGN.to_owned(),
480                        Value::String("1".to_owned()),
481                    );
482                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
483                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
484                    })?;
485                    self.client
486                        .as_mut()
487                        .ok_or_else(|| {
488                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
489                        })?
490                        .write(row_json_string.into())
491                        .await?;
492                }
493                Op::UpdateDelete => {}
494                Op::UpdateInsert => {
495                    let mut row_json_value = self.row_encoder.encode(row)?;
496                    row_json_value.insert(
497                        STARROCKS_DELETE_SIGN.to_owned(),
498                        Value::String("0".to_owned()),
499                    );
500                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
501                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
502                    })?;
503                    self.client
504                        .as_mut()
505                        .ok_or_else(|| {
506                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
507                        })?
508                        .write(row_json_string.into())
509                        .await?;
510                }
511            }
512        }
513        Ok(())
514    }
515
516    /// Generating a new transaction label, should be unique across all `SinkWriters` even under rewinding.
517    #[inline(always)]
518    fn new_txn_label(&self) -> String {
519        format!(
520            "rw-txn-{}-{}",
521            self.executor_id,
522            chrono::Utc::now().timestamp_micros()
523        )
524    }
525
526    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
527        tracing::debug!(?txn_label, "prepare transaction");
528        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
529        if txn_label != txn_label_res {
530            return Err(SinkError::Starrocks(format!(
531                "label {} returned from prepare transaction {} differs from the current one",
532                txn_label, txn_label_res
533            )));
534        }
535        tracing::debug!(?txn_label, "commit transaction");
536        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
537        if txn_label != txn_label_res {
538            return Err(SinkError::Starrocks(format!(
539                "label {} returned from commit transaction {} differs from the current one",
540                txn_label, txn_label_res
541            )));
542        }
543        Ok(())
544    }
545}
546
547impl Drop for StarrocksSinkWriter {
548    fn drop(&mut self) {
549        if let Some(txn_label) = self.curr_txn_label.take() {
550            let txn_client = self.txn_client.clone();
551            tokio::spawn(async move {
552                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
553                    tracing::error!(
554                        "starrocks rollback transaction error: {:?}, txn label: {}",
555                        e.as_report(),
556                        txn_label
557                    );
558                }
559            });
560        }
561    }
562}
563
564#[async_trait]
565impl SinkWriter for StarrocksSinkWriter {
566    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
567        Ok(())
568    }
569
570    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
571        // We check whether start a new transaction in `write_batch`. Therefore, if no data has been written
572        // within the `commit_checkpoint_interval` period, no meta requests will be made. Otherwise if we request
573        // `prepare` against an empty transaction, the `StarRocks` will report a `hasn't send any data yet` error.
574        if self.curr_txn_label.is_none() {
575            let txn_label = self.new_txn_label();
576            tracing::debug!(?txn_label, "begin transaction");
577            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
578            if txn_label != txn_label_res {
579                return Err(SinkError::Starrocks(format!(
580                    "label {} returned from StarRocks {} differs from generated one",
581                    txn_label, txn_label_res
582                )));
583            }
584            self.curr_txn_label = Some(txn_label.clone());
585        }
586        if self.client.is_none() {
587            let txn_label = self.curr_txn_label.clone();
588            self.client = Some(StarrocksClient::new(
589                self.txn_client.load(txn_label.unwrap()).await?,
590            ));
591        }
592        if self.is_append_only {
593            self.append_only(chunk).await
594        } else {
595            self.upsert(chunk).await
596        }
597    }
598
599    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
600        if let Some(client) = self.client.take() {
601            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
602            // one or more load requests should be made within one commit_checkpoint_interval period.
603            // StarRocks will take care of merging those splits into a larger one during prepare transaction.
604            // Thus, only one version will be produced when the transaction is committed. See Stream Load
605            // transaction interface for more information.
606            client.finish().await?;
607        }
608
609        if is_checkpoint
610            && let Some(txn_label) = self.curr_txn_label.take()
611            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
612        {
613            match self.txn_client.rollback(txn_label.clone()).await {
614                Ok(_) => tracing::warn!(
615                    ?txn_label,
616                    "transaction is successfully rolled back due to commit failure"
617                ),
618                Err(err) => {
619                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
620                }
621            }
622
623            return Err(err);
624        }
625        Ok(())
626    }
627
628    async fn abort(&mut self) -> Result<()> {
629        if let Some(txn_label) = self.curr_txn_label.take() {
630            tracing::debug!(?txn_label, "rollback transaction");
631            self.txn_client.rollback(txn_label).await?;
632        }
633        Ok(())
634    }
635}
636
/// MySQL-protocol client used to read table metadata from `StarRocks`
/// (`information_schema`) for one specific table.
pub struct StarrocksSchemaClient {
    // Target table name, used as `table_name` in the metadata queries.
    table: String,
    // Target database name, used as `table_schema` in the metadata queries.
    db: String,
    conn: mysql_async::Conn,
}
642
643impl StarrocksSchemaClient {
644    pub async fn new(
645        host: String,
646        port: String,
647        table: String,
648        db: String,
649        user: String,
650        password: String,
651    ) -> Result<Self> {
652        // username & password may contain special chars, so we need to do URL encoding on them.
653        // Otherwise, Opts::from_url may report a `Parse error`
654        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
655        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
656
657        let conn_uri = format!(
658            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
659            user,
660            password,
661            host,
662            port,
663            db,
664            STARROCK_MYSQL_PREFER_SOCKET,
665            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
666            STARROCK_MYSQL_WAIT_TIMEOUT
667        );
668        let pool = mysql_async::Pool::new(
669            Opts::from_url(&conn_uri)
670                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
671        );
672        let conn = pool
673            .get_conn()
674            .await
675            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
676
677        Ok(Self { table, db, conn })
678    }
679
680    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
681        let query = format!(
682            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
683            self.table, self.db
684        );
685        let mut query_map: HashMap<String, String> = HashMap::default();
686        self.conn
687            .query_map(query, |(column_name, column_type)| {
688                query_map.insert(column_name, column_type)
689            })
690            .await
691            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
692        Ok(query_map)
693    }
694
695    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
696        let query = format!(
697            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
698            self.table, self.db
699        );
700        let table_mode_pk: (String, String) = self
701            .conn
702            .query_map(
703                query,
704                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
705                    .as_str()
706                {
707                    // Get primary key of aggregate table from the sort_key field
708                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
709                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
710                    "AGG_KEYS" => (table_model, sort_key),
711                    _ => (table_model, primary_key),
712                },
713            )
714            .await
715            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
716            .first()
717            .ok_or_else(|| {
718                SinkError::Starrocks(format!(
719                    "Can't find schema with table {:?} and database {:?}",
720                    self.table, self.db
721                ))
722            })?
723            .clone();
724        Ok(table_mode_pk)
725    }
726}
727
/// JSON body returned by `StarRocks` for stream load and transaction requests.
///
/// Only `Status` and `Message` are required when deserializing; every other key is
/// optional. This file reads `status`, `message`, `error_url` and `label`.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    // Transaction label echoed back; compared against the label that was sent.
    #[serde(rename = "Label")]
    pub label: Option<String>,
    // Checked against `STARROCKS_SUCCESS_STATUS` to decide success or failure.
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    // Included in the error text when a request is reported as failed.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
767
/// Handle on one open load request; rows are written to it and it is finished once.
pub struct StarrocksClient {
    insert: InserterInner,
}
771impl StarrocksClient {
772    pub fn new(insert: InserterInner) -> Self {
773        Self { insert }
774    }
775
776    pub async fn write(&mut self, data: Bytes) -> Result<()> {
777        self.insert.write(data).await?;
778        Ok(())
779    }
780
781    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
782        let raw = self.insert.finish().await?;
783        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
784            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
785
786        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
787            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
788                "Insert error: {}, {}, {:?}",
789                res.status,
790                res.message,
791                res.error_url,
792            )));
793        };
794        Ok(res)
795    }
796}
797
/// Client for the transaction lifecycle requests (begin, prepare, commit, rollback)
/// and for opening load requests under a transaction label.
pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}
801
802impl StarrocksTxnClient {
803    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
804        Self { request_builder }
805    }
806
807    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
808        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
809            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
810        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
811            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
812                "transaction error: {}, {}, {:?}",
813                res.status,
814                res.message,
815                res.error_url,
816            )));
817        }
818        res.label.ok_or_else(|| {
819            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
820        })
821    }
822
823    pub async fn begin(&self, label: String) -> Result<String> {
824        let res = self
825            .request_builder
826            .build_begin_request_sender(label)?
827            .send()
828            .await?;
829        self.check_response_and_extract_label(res)
830    }
831
832    pub async fn prepare(&self, label: String) -> Result<String> {
833        let res = self
834            .request_builder
835            .build_prepare_request_sender(label)?
836            .send()
837            .await?;
838        self.check_response_and_extract_label(res)
839    }
840
841    pub async fn commit(&self, label: String) -> Result<String> {
842        let res = self
843            .request_builder
844            .build_commit_request_sender(label)?
845            .send()
846            .await?;
847        self.check_response_and_extract_label(res)
848    }
849
850    pub async fn rollback(&self, label: String) -> Result<String> {
851        let res = self
852            .request_builder
853            .build_rollback_request_sender(label)?
854            .send()
855            .await?;
856        self.check_response_and_extract_label(res)
857    }
858
859    pub async fn load(&self, label: String) -> Result<InserterInner> {
860        self.request_builder.build_txn_inserter(label).await
861    }
862}