risingwave_connector/sink/
starrocks.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use mysql_async::Opts;
23use mysql_async::prelude::Queryable;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::types::DataType;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use serde_with::{DisplayFromStr, serde_as};
30use thiserror_ext::AsReport;
31use url::form_urlencoded;
32use with_options::WithOptions;
33
34use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
35use super::doris_starrocks_connector::{
36    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
37    StarrocksTxnRequestBuilder,
38};
39use super::encoder::{JsonEncoder, RowEncoder};
40use super::{
41    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
42    SinkWriterMetrics,
43};
44use crate::enforce_secret::EnforceSecret;
45use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
46use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};
47
/// Connector name of this sink; also exposed as `Sink::SINK_NAME` for `StarrocksSink`.
pub const STARROCKS_SINK: &str = "starrocks";
/// Value of the `prefer_socket` parameter put into the MySQL connection URI.
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
/// Value of the `max_allowed_packet` parameter put into the MySQL connection URI.
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
/// Value of the `wait_timeout` parameter put into the MySQL connection URI.
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;
52
/// Default timeout of a stream load HTTP request, in milliseconds (30 seconds).
pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    30_000
}
56
/// Serde default for `starrocks.use_https`: plain HTTP unless the option is set.
const fn default_use_https() -> bool {
    false
}
60
/// Connection-level options of the `StarRocks` sink, deserialized from the
/// `starrocks.*` entries of the sink properties.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use https to connect to the `StarRocks` server.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}
92
impl EnforceSecret for StarrocksCommon {
    // Option keys subject to secret enforcement.
    // NOTE(review): presumably these must be supplied through secret references
    // rather than plain text — confirm against the `EnforceSecret` trait.
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}
98
/// Full configuration of the `StarRocks` sink: the connection options in
/// [`StarrocksCommon`] plus stream-load and commit behavior.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// The timeout in milliseconds for stream load http request, defaults to 30 seconds.
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    /// Set this option to a positive integer n, RisingWave will try to commit data
    /// to Starrocks at every n checkpoints by leveraging the
    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/),
    /// also, in this time, the `sink_decouple` option should be enabled as well.
    /// Defaults to `DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE` when the option
    /// is not set; a value of 0 is rejected during config validation.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable partial update
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    pub r#type: String, // accept "append-only" or "upsert"
}
130
impl EnforceSecret for StarrocksConfig {
    /// Checks a single property key by delegating to [`StarrocksCommon`],
    /// which declares the secret-enforced option keys.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}
136
/// Serde default for `commit_checkpoint_interval`: the shared default used
/// by sinks running with sink decoupling.
fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}
140
141impl StarrocksConfig {
142    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
143        let config =
144            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
145                .map_err(|e| SinkError::Config(anyhow!(e)))?;
146        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
147            return Err(SinkError::Config(anyhow!(
148                "`{}` must be {}, or {}",
149                SINK_TYPE_OPTION,
150                SINK_TYPE_APPEND_ONLY,
151                SINK_TYPE_UPSERT
152            )));
153        }
154        if config.commit_checkpoint_interval == 0 {
155            return Err(SinkError::Config(anyhow!(
156                "`commit_checkpoint_interval` must be greater than 0"
157            )));
158        }
159        Ok(config)
160    }
161}
162
/// The `StarRocks` sink: validated configuration plus the schema of the
/// rows to be written.
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    // Schema of the rows handed to this sink.
    schema: Schema,
    // Indices (into `schema`) of the downstream primary key columns.
    pk_indices: Vec<usize>,
    // `true` for append-only sinks, `false` for upsert sinks.
    is_append_only: bool,
}
170
171impl EnforceSecret for StarrocksSink {
172    fn enforce_secret<'a>(
173        prop_iter: impl Iterator<Item = &'a str>,
174    ) -> crate::error::ConnectorResult<()> {
175        for prop in prop_iter {
176            StarrocksConfig::enforce_one(prop)?;
177        }
178        Ok(())
179    }
180}
181
182impl StarrocksSink {
183    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
184        let pk_indices = param.downstream_pk.clone();
185        let is_append_only = param.sink_type.is_append_only();
186        Ok(Self {
187            config,
188            schema,
189            pk_indices,
190            is_append_only,
191        })
192    }
193}
194
195impl StarrocksSink {
196    fn check_column_name_and_type(
197        &self,
198        starrocks_columns_desc: HashMap<String, String>,
199    ) -> Result<()> {
200        let rw_fields_name = self.schema.fields();
201        if rw_fields_name.len() > starrocks_columns_desc.len() {
202            return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
203        }
204
205        for i in rw_fields_name {
206            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
207                SinkError::Starrocks(format!(
208                    "Column name don't find in starrocks, risingwave is {:?} ",
209                    i.name
210                ))
211            })?;
212            if !Self::check_and_correct_column_type(&i.data_type, value)? {
213                return Err(SinkError::Starrocks(format!(
214                    "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
215                    i.name, value, i.data_type
216                )));
217            }
218        }
219        Ok(())
220    }
221
222    fn check_and_correct_column_type(
223        rw_data_type: &DataType,
224        starrocks_data_type: &String,
225    ) -> Result<bool> {
226        match rw_data_type {
227            risingwave_common::types::DataType::Boolean => {
228                Ok(starrocks_data_type.contains("tinyint") | starrocks_data_type.contains("boolean"))
229            }
230            risingwave_common::types::DataType::Int16 => {
231                Ok(starrocks_data_type.contains("smallint"))
232            }
233            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
234            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
235            risingwave_common::types::DataType::Float32 => {
236                Ok(starrocks_data_type.contains("float"))
237            }
238            risingwave_common::types::DataType::Float64 => {
239                Ok(starrocks_data_type.contains("double"))
240            }
241            risingwave_common::types::DataType::Decimal => {
242                Ok(starrocks_data_type.contains("decimal"))
243            }
244            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
245            risingwave_common::types::DataType::Varchar => {
246                Ok(starrocks_data_type.contains("varchar"))
247            }
248            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
249                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
250            )),
251            risingwave_common::types::DataType::Timestamp => {
252                Ok(starrocks_data_type.contains("datetime"))
253            }
254            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
255                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
256            )),
257            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
258                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
259            )),
260            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
261                "STRUCT is not supported for Starrocks sink.".to_owned(),
262            )),
263            risingwave_common::types::DataType::List(list) => {
264                // For compatibility with older versions starrocks
265                if starrocks_data_type.contains("unknown") {
266                    return Ok(true);
267                }
268                let check_result = Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
269                Ok(check_result && starrocks_data_type.contains("array"))
270            }
271            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
272                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
273            )),
274            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
275            risingwave_common::types::DataType::Serial => {
276                Ok(starrocks_data_type.contains("bigint"))
277            }
278            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
279                "INT256 is not supported for Starrocks sink.".to_owned(),
280            )),
281            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
282                "MAP is not supported for Starrocks sink.".to_owned(),
283            )),
284            DataType::Vector(_) => Err(SinkError::Starrocks(
285                "VECTOR is not supported for Starrocks sink.".to_owned(),
286            )),
287        }
288    }
289}
290
291impl Sink for StarrocksSink {
292    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
293
294    const SINK_NAME: &'static str = STARROCKS_SINK;
295
296    async fn validate(&self) -> Result<()> {
297        if !self.is_append_only && self.pk_indices.is_empty() {
298            return Err(SinkError::Config(anyhow!(
299                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
300            )));
301        }
302        // check reachability
303        let mut client = StarrocksSchemaClient::new(
304            self.config.common.host.clone(),
305            self.config.common.mysql_port.clone(),
306            self.config.common.table.clone(),
307            self.config.common.database.clone(),
308            self.config.common.user.clone(),
309            self.config.common.password.clone(),
310        )
311        .await?;
312        let (read_model, pks) = client.get_pk_from_starrocks().await?;
313
314        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
315            return Err(SinkError::Config(anyhow!(
316                "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
317            )));
318        }
319
320        for (index, filed) in self.schema.fields().iter().enumerate() {
321            if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
322                return Err(SinkError::Starrocks(format!(
323                    "Can't find pk {:?} in starrocks",
324                    filed.name
325                )));
326            }
327        }
328
329        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
330
331        self.check_column_name_and_type(starrocks_columns_desc)?;
332        Ok(())
333    }
334
335    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
336        StarrocksConfig::from_btreemap(config.clone())?;
337        Ok(())
338    }
339
340    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
341        let commit_checkpoint_interval =
342            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
343                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
344            );
345
346        let writer = StarrocksSinkWriter::new(
347            self.config.clone(),
348            self.schema.clone(),
349            self.pk_indices.clone(),
350            self.is_append_only,
351            writer_param.executor_id,
352        )?;
353
354        let metrics = SinkWriterMetrics::new(&writer_param);
355
356        Ok(DecoupleCheckpointLogSinkerOf::new(
357            writer,
358            metrics,
359            commit_checkpoint_interval,
360        ))
361    }
362}
363
/// Writer that streams rows into `StarRocks` through the stream load
/// transaction interface.
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    // Selects between the append-only and the upsert write path.
    is_append_only: bool,
    // Currently open load request, if any; created lazily in `write_batch`
    // and finished on every barrier.
    client: Option<StarrocksClient>,
    // Shared so that `Drop` can roll back from a spawned task.
    txn_client: Arc<StarrocksTxnClient>,
    // Encodes each row as a JSON object.
    row_encoder: JsonEncoder,
    // Used as part of the generated transaction label.
    executor_id: u64,
    // Label of the transaction currently in progress, if any.
    curr_txn_label: Option<String>,
}
377
378impl TryFrom<SinkParam> for StarrocksSink {
379    type Error = SinkError;
380
381    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
382        let schema = param.schema();
383        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
384        StarrocksSink::new(param, config, schema)
385    }
386}
387
388impl StarrocksSinkWriter {
389    pub fn new(
390        config: StarrocksConfig,
391        schema: Schema,
392        pk_indices: Vec<usize>,
393        is_append_only: bool,
394        executor_id: u64,
395    ) -> Result<Self> {
396        let mut field_names = schema.names_str();
397        if !is_append_only {
398            field_names.push(STARROCKS_DELETE_SIGN);
399        };
400        // we should quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request due to
401        // a field name being a reserved word. For example, `order`, 'from`, etc.
402        let field_names = field_names
403            .into_iter()
404            .map(|name| format!("`{}`", name))
405            .collect::<Vec<String>>();
406        let field_names_str = field_names
407            .iter()
408            .map(|name| name.as_str())
409            .collect::<Vec<&str>>();
410
411        let header = HeaderBuilder::new()
412            .add_common_header()
413            .set_user_password(config.common.user.clone(), config.common.password.clone())
414            .add_json_format()
415            .set_partial_update(config.partial_update.clone())
416            .set_columns_name(field_names_str)
417            .set_db(config.common.database.clone())
418            .set_table(config.common.table.clone())
419            .build();
420
421        let url = if config.common.use_https {
422            format!("https://{}:{}", config.common.host, config.common.http_port)
423        } else {
424            format!("http://{}:{}", config.common.host, config.common.http_port)
425        };
426        let txn_request_builder =
427            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
428
429        Ok(Self {
430            config,
431            schema: schema.clone(),
432            pk_indices,
433            is_append_only,
434            client: None,
435            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
436            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
437            executor_id,
438            curr_txn_label: None,
439        })
440    }
441
442    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
443        for (op, row) in chunk.rows() {
444            if op != Op::Insert {
445                continue;
446            }
447            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
448            self.client
449                .as_mut()
450                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
451                .write(row_json_string.into())
452                .await?;
453        }
454        Ok(())
455    }
456
457    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
458        for (op, row) in chunk.rows() {
459            match op {
460                Op::Insert => {
461                    let mut row_json_value = self.row_encoder.encode(row)?;
462                    row_json_value.insert(
463                        STARROCKS_DELETE_SIGN.to_owned(),
464                        Value::String("0".to_owned()),
465                    );
466                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
467                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
468                    })?;
469                    self.client
470                        .as_mut()
471                        .ok_or_else(|| {
472                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
473                        })?
474                        .write(row_json_string.into())
475                        .await?;
476                }
477                Op::Delete => {
478                    let mut row_json_value = self.row_encoder.encode(row)?;
479                    row_json_value.insert(
480                        STARROCKS_DELETE_SIGN.to_owned(),
481                        Value::String("1".to_owned()),
482                    );
483                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
484                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
485                    })?;
486                    self.client
487                        .as_mut()
488                        .ok_or_else(|| {
489                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
490                        })?
491                        .write(row_json_string.into())
492                        .await?;
493                }
494                Op::UpdateDelete => {}
495                Op::UpdateInsert => {
496                    let mut row_json_value = self.row_encoder.encode(row)?;
497                    row_json_value.insert(
498                        STARROCKS_DELETE_SIGN.to_owned(),
499                        Value::String("0".to_owned()),
500                    );
501                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
502                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
503                    })?;
504                    self.client
505                        .as_mut()
506                        .ok_or_else(|| {
507                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
508                        })?
509                        .write(row_json_string.into())
510                        .await?;
511                }
512            }
513        }
514        Ok(())
515    }
516
517    /// Generating a new transaction label, should be unique across all `SinkWriters` even under rewinding.
518    #[inline(always)]
519    fn new_txn_label(&self) -> String {
520        format!(
521            "rw-txn-{}-{}",
522            self.executor_id,
523            chrono::Utc::now().timestamp_micros()
524        )
525    }
526
527    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
528        tracing::debug!(?txn_label, "prepare transaction");
529        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
530        if txn_label != txn_label_res {
531            return Err(SinkError::Starrocks(format!(
532                "label {} returned from prepare transaction {} differs from the current one",
533                txn_label, txn_label_res
534            )));
535        }
536        tracing::debug!(?txn_label, "commit transaction");
537        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
538        if txn_label != txn_label_res {
539            return Err(SinkError::Starrocks(format!(
540                "label {} returned from commit transaction {} differs from the current one",
541                txn_label, txn_label_res
542            )));
543        }
544        Ok(())
545    }
546}
547
548impl Drop for StarrocksSinkWriter {
549    fn drop(&mut self) {
550        if let Some(txn_label) = self.curr_txn_label.take() {
551            let txn_client = self.txn_client.clone();
552            tokio::spawn(async move {
553                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
554                    tracing::error!(
555                        "starrocks rollback transaction error: {:?}, txn label: {}",
556                        e.as_report(),
557                        txn_label
558                    );
559                }
560            });
561        }
562    }
563}
564
#[async_trait]
impl SinkWriter for StarrocksSinkWriter {
    /// No per-epoch work is needed; transactions are started lazily in `write_batch`.
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    /// Writes `chunk` into the current transaction, beginning a transaction
    /// and opening a load request first if none is in progress.
    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // We check whether start a new transaction in `write_batch`. Therefore, if no data has been written
        // within the `commit_checkpoint_interval` period, no meta requests will be made. Otherwise if we request
        // `prepare` against an empty transaction, the `StarRocks` will report a `hasn't send any data yet` error.
        if self.curr_txn_label.is_none() {
            let txn_label = self.new_txn_label();
            tracing::debug!(?txn_label, "begin transaction");
            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
            // The label echoed back must be the one we generated.
            if txn_label != txn_label_res {
                return Err(SinkError::Starrocks(format!(
                    "label {} returned from StarRocks {} differs from generated one",
                    txn_label, txn_label_res
                )));
            }
            self.curr_txn_label = Some(txn_label.clone());
        }
        // The load client is dropped on every barrier, so it is re-created here
        // for the first batch after a barrier.
        if self.client.is_none() {
            let txn_label = self.curr_txn_label.clone();
            self.client = Some(StarrocksClient::new(
                // `curr_txn_label` is always `Some` here: it was either already set or set just above.
                self.txn_client.load(txn_label.unwrap()).await?,
            ));
        }
        if self.is_append_only {
            self.append_only(chunk).await
        } else {
            self.upsert(chunk).await
        }
    }

    /// Finishes the open load request (if any) and, on a checkpoint barrier,
    /// prepares and commits the current transaction.
    ///
    /// If prepare/commit fails, a rollback is attempted and the original
    /// prepare/commit error is returned; the rollback outcome is only logged.
    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        if let Some(client) = self.client.take() {
            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
            // one or more load requests should be made within one commit_checkpoint_interval period.
            // StarRocks will take care of merging those splits into a larger one during prepare transaction.
            // Thus, only one version will be produced when the transaction is committed. See Stream Load
            // transaction interface for more information.
            client.finish().await?;
        }

        // The label is taken (cleared) on every checkpoint barrier that has one,
        // whether or not prepare/commit succeeds; the body runs only on failure.
        if is_checkpoint
            && let Some(txn_label) = self.curr_txn_label.take()
            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
        {
            match self.txn_client.rollback(txn_label.clone()).await {
                Ok(_) => tracing::warn!(
                    ?txn_label,
                    "transaction is successfully rolled back due to commit failure"
                ),
                Err(err) => {
                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
                }
            }

            return Err(err);
        }
        Ok(())
    }

    /// Rolls back the current transaction, if any. The label is cleared before
    /// the rollback request is sent, and a rollback error is returned to the caller.
    async fn abort(&mut self) -> Result<()> {
        if let Some(txn_label) = self.curr_txn_label.take() {
            tracing::debug!(?txn_label, "rollback transaction");
            self.txn_client.rollback(txn_label).await?;
        }
        Ok(())
    }
}
637
/// MySQL-protocol client used to read table metadata of one `StarRocks` table
/// from `information_schema`.
pub struct StarrocksSchemaClient {
    // Name of the target table.
    table: String,
    // Name of the database containing `table`.
    db: String,
    // Open MySQL connection used for the metadata queries.
    conn: mysql_async::Conn,
}
643
644impl StarrocksSchemaClient {
645    pub async fn new(
646        host: String,
647        port: String,
648        table: String,
649        db: String,
650        user: String,
651        password: String,
652    ) -> Result<Self> {
653        // username & password may contain special chars, so we need to do URL encoding on them.
654        // Otherwise, Opts::from_url may report a `Parse error`
655        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
656        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
657
658        let conn_uri = format!(
659            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
660            user,
661            password,
662            host,
663            port,
664            db,
665            STARROCK_MYSQL_PREFER_SOCKET,
666            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
667            STARROCK_MYSQL_WAIT_TIMEOUT
668        );
669        let pool = mysql_async::Pool::new(
670            Opts::from_url(&conn_uri)
671                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
672        );
673        let conn = pool
674            .get_conn()
675            .await
676            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
677
678        Ok(Self { table, db, conn })
679    }
680
681    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
682        let query = format!(
683            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
684            self.table, self.db
685        );
686        let mut query_map: HashMap<String, String> = HashMap::default();
687        self.conn
688            .query_map(query, |(column_name, column_type)| {
689                query_map.insert(column_name, column_type)
690            })
691            .await
692            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
693        Ok(query_map)
694    }
695
696    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
697        let query = format!(
698            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
699            self.table, self.db
700        );
701        let table_mode_pk: (String, String) = self
702            .conn
703            .query_map(
704                query,
705                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
706                    .as_str()
707                {
708                    // Get primary key of aggregate table from the sort_key field
709                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
710                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
711                    "AGG_KEYS" => (table_model, sort_key),
712                    _ => (table_model, primary_key),
713                },
714            )
715            .await
716            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
717            .first()
718            .ok_or_else(|| {
719                SinkError::Starrocks(format!(
720                    "Can't find schema with table {:?} and database {:?}",
721                    self.table, self.db
722                ))
723            })?
724            .clone();
725        Ok(table_mode_pk)
726    }
727}
728
/// JSON response body of a `StarRocks` stream load or transaction request.
///
/// Only `Status` and `Message` are required when deserializing; every other
/// field is optional.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    // Transaction label echoed back by the server; compared against the
    // requested label by the transaction client's callers.
    #[serde(rename = "Label")]
    pub label: Option<String>,
    // Checked against `STARROCKS_SUCCESS_STATUS` to decide success.
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    // Included in the error message when a request does not succeed.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
768
/// Wrapper around one in-flight load request (`InserterInner`).
pub struct StarrocksClient {
    insert: InserterInner,
}
772impl StarrocksClient {
773    pub fn new(insert: InserterInner) -> Self {
774        Self { insert }
775    }
776
777    pub async fn write(&mut self, data: Bytes) -> Result<()> {
778        self.insert.write(data).await?;
779        Ok(())
780    }
781
782    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
783        let raw = self.insert.finish().await?;
784        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
785            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
786
787        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
788            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
789                "Insert error: {}, {}, {:?}",
790                res.status,
791                res.message,
792                res.error_url,
793            )));
794        };
795        Ok(res)
796    }
797}
798
/// Client for the transaction requests (begin, prepare, commit, rollback, load),
/// all built by a `StarrocksTxnRequestBuilder`.
pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}
802
803impl StarrocksTxnClient {
804    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
805        Self { request_builder }
806    }
807
808    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
809        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
810            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
811        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
812            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
813                "transaction error: {}, {}, {:?}",
814                res.status,
815                res.message,
816                res.error_url,
817            )));
818        }
819        res.label.ok_or_else(|| {
820            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
821        })
822    }
823
824    pub async fn begin(&self, label: String) -> Result<String> {
825        let res = self
826            .request_builder
827            .build_begin_request_sender(label)?
828            .send()
829            .await?;
830        self.check_response_and_extract_label(res)
831    }
832
833    pub async fn prepare(&self, label: String) -> Result<String> {
834        let res = self
835            .request_builder
836            .build_prepare_request_sender(label)?
837            .send()
838            .await?;
839        self.check_response_and_extract_label(res)
840    }
841
842    pub async fn commit(&self, label: String) -> Result<String> {
843        let res = self
844            .request_builder
845            .build_commit_request_sender(label)?
846            .send()
847            .await?;
848        self.check_response_and_extract_label(res)
849    }
850
851    pub async fn rollback(&self, label: String) -> Result<String> {
852        let res = self
853            .request_builder
854            .build_rollback_request_sender(label)?
855            .send()
856            .await?;
857        self.check_response_and_extract_label(res)
858    }
859
860    pub async fn load(&self, label: String) -> Result<InserterInner> {
861        self.request_builder.build_txn_inserter(label).await
862    }
863}