risingwave_connector/sink/
starrocks.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use mysql_async::Opts;
23use mysql_async::prelude::Queryable;
24use risingwave_common::array::{Op, StreamChunk};
25use risingwave_common::catalog::Schema;
26use risingwave_common::types::DataType;
27use serde::Deserialize;
28use serde_derive::Serialize;
29use serde_json::Value;
30use serde_with::{DisplayFromStr, serde_as};
31use thiserror_ext::AsReport;
32use url::form_urlencoded;
33use with_options::WithOptions;
34
35use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
36use super::doris_starrocks_connector::{
37    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
38    StarrocksTxnRequestBuilder,
39};
40use super::encoder::{JsonEncoder, RowEncoder};
41use super::{
42    DummySinkCommitCoordinator, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
43    SinkError, SinkParam, SinkWriterMetrics,
44};
45use crate::enforce_secret::EnforceSecret;
46use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
47use crate::sink::{Result, Sink, SinkWriter, SinkWriterParam};
48
/// Connector name of this sink; used as `Sink::SINK_NAME` for `StarrocksSink`.
pub const STARROCKS_SINK: &str = "starrocks";
// Query parameters appended to the MySQL connection URL built by `StarrocksSchemaClient::new`
// (`prefer_socket`, `max_allowed_packet` and `wait_timeout` respectively).
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;
53
/// Serde default for `starrocks.stream_load.http.timeout.ms`: 30 seconds, in milliseconds.
const fn _default_stream_load_http_timeout_ms() -> u64 {
    const THIRTY_SECONDS_IN_MS: u64 = 30_000;
    THIRTY_SECONDS_IN_MS
}
57
/// Serde default for `starrocks.use_https`: plain HTTP unless explicitly enabled.
const fn default_use_https() -> bool {
    const USE_HTTPS_BY_DEFAULT: bool = false;
    USE_HTTPS_BY_DEFAULT
}
61
// Connection settings and target table of the StarRocks sink. `StarrocksConfig` flattens this
// struct, so these keys are read from the top level of the sink's properties. The
// `EnforceSecret` impl for this type treats `starrocks.user` and `starrocks.password` as secrets.
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use https to connect to the `StarRocks` server.
    // Selects the `https://` vs `http://` scheme of the stream-load URL built by the writer.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    pub use_https: bool,
}
91
92impl EnforceSecret for StarrocksCommon {
93    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
94        "starrocks.password", "starrocks.user"
95    };
96}
97
98#[serde_as]
99#[derive(Clone, Debug, Deserialize, WithOptions)]
100pub struct StarrocksConfig {
101    #[serde(flatten)]
102    pub common: StarrocksCommon,
103
104    /// The timeout in milliseconds for stream load http request, defaults to 10 seconds.
105    #[serde(
106        rename = "starrocks.stream_load.http.timeout.ms",
107        default = "_default_stream_load_http_timeout_ms"
108    )]
109    #[serde_as(as = "DisplayFromStr")]
110    pub stream_load_http_timeout_ms: u64,
111
112    /// Set this option to a positive integer n, RisingWave will try to commit data
113    /// to Starrocks at every n checkpoints by leveraging the
114    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/),
115    /// also, in this time, the `sink_decouple` option should be enabled as well.
116    /// Defaults to 10 if `commit_checkpoint_interval` <= 0
117    #[serde(default = "default_commit_checkpoint_interval")]
118    #[serde_as(as = "DisplayFromStr")]
119    #[with_option(allow_alter_on_fly)]
120    pub commit_checkpoint_interval: u64,
121
122    /// Enable partial update
123    #[serde(rename = "starrocks.partial_update")]
124    pub partial_update: Option<String>,
125
126    pub r#type: String, // accept "append-only" or "upsert"
127}
128
129impl EnforceSecret for StarrocksConfig {
130    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
131        StarrocksCommon::enforce_one(prop)
132    }
133}
134
135fn default_commit_checkpoint_interval() -> u64 {
136    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
137}
138
139impl StarrocksConfig {
140    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
141        let config =
142            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
143                .map_err(|e| SinkError::Config(anyhow!(e)))?;
144        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
145            return Err(SinkError::Config(anyhow!(
146                "`{}` must be {}, or {}",
147                SINK_TYPE_OPTION,
148                SINK_TYPE_APPEND_ONLY,
149                SINK_TYPE_UPSERT
150            )));
151        }
152        if config.commit_checkpoint_interval == 0 {
153            return Err(SinkError::Config(anyhow!(
154                "`commit_checkpoint_interval` must be greater than 0"
155            )));
156        }
157        Ok(config)
158    }
159}
160
/// Sink that writes RisingWave stream chunks into a `StarRocks` table.
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    // Indices into `schema` of the downstream primary-key columns (`SinkParam::downstream_pk`).
    pk_indices: Vec<usize>,
    // `true` for append-only sinks; `false` for upsert sinks, whose writer also sends the
    // StarRocks delete-sign column.
    is_append_only: bool,
}
168
169impl EnforceSecret for StarrocksSink {
170    fn enforce_secret<'a>(
171        prop_iter: impl Iterator<Item = &'a str>,
172    ) -> crate::error::ConnectorResult<()> {
173        for prop in prop_iter {
174            StarrocksConfig::enforce_one(prop)?;
175        }
176        Ok(())
177    }
178}
179
180impl StarrocksSink {
181    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
182        let pk_indices = param.downstream_pk.clone();
183        let is_append_only = param.sink_type.is_append_only();
184        Ok(Self {
185            config,
186            schema,
187            pk_indices,
188            is_append_only,
189        })
190    }
191}
192
193impl StarrocksSink {
194    fn check_column_name_and_type(
195        &self,
196        starrocks_columns_desc: HashMap<String, String>,
197    ) -> Result<()> {
198        let rw_fields_name = self.schema.fields();
199        if rw_fields_name.len() > starrocks_columns_desc.len() {
200            return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
201        }
202
203        for i in rw_fields_name {
204            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
205                SinkError::Starrocks(format!(
206                    "Column name don't find in starrocks, risingwave is {:?} ",
207                    i.name
208                ))
209            })?;
210            if !Self::check_and_correct_column_type(&i.data_type, value)? {
211                return Err(SinkError::Starrocks(format!(
212                    "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
213                    i.name, value, i.data_type
214                )));
215            }
216        }
217        Ok(())
218    }
219
220    fn check_and_correct_column_type(
221        rw_data_type: &DataType,
222        starrocks_data_type: &String,
223    ) -> Result<bool> {
224        match rw_data_type {
225            risingwave_common::types::DataType::Boolean => {
226                Ok(starrocks_data_type.contains("tinyint") | starrocks_data_type.contains("boolean"))
227            }
228            risingwave_common::types::DataType::Int16 => {
229                Ok(starrocks_data_type.contains("smallint"))
230            }
231            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
232            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
233            risingwave_common::types::DataType::Float32 => {
234                Ok(starrocks_data_type.contains("float"))
235            }
236            risingwave_common::types::DataType::Float64 => {
237                Ok(starrocks_data_type.contains("double"))
238            }
239            risingwave_common::types::DataType::Decimal => {
240                Ok(starrocks_data_type.contains("decimal"))
241            }
242            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
243            risingwave_common::types::DataType::Varchar => {
244                Ok(starrocks_data_type.contains("varchar"))
245            }
246            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
247                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
248            )),
249            risingwave_common::types::DataType::Timestamp => {
250                Ok(starrocks_data_type.contains("datetime"))
251            }
252            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
253                "TIMESTAMP WITH TIMEZONE is not supported for Starrocks sink as Starrocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
254            )),
255            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
256                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
257            )),
258            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
259                "STRUCT is not supported for Starrocks sink.".to_owned(),
260            )),
261            risingwave_common::types::DataType::List(list) => {
262                // For compatibility with older versions starrocks
263                if starrocks_data_type.contains("unknown") {
264                    return Ok(true);
265                }
266                let check_result = Self::check_and_correct_column_type(list.as_ref(), starrocks_data_type)?;
267                Ok(check_result && starrocks_data_type.contains("array"))
268            }
269            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
270                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
271            )),
272            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
273            risingwave_common::types::DataType::Serial => {
274                Ok(starrocks_data_type.contains("bigint"))
275            }
276            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
277                "INT256 is not supported for Starrocks sink.".to_owned(),
278            )),
279            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
280                "MAP is not supported for Starrocks sink.".to_owned(),
281            )),
282            DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
283        }
284    }
285}
286
287impl Sink for StarrocksSink {
288    type Coordinator = DummySinkCommitCoordinator;
289    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
290
291    const SINK_NAME: &'static str = STARROCKS_SINK;
292
293    async fn validate(&self) -> Result<()> {
294        if !self.is_append_only && self.pk_indices.is_empty() {
295            return Err(SinkError::Config(anyhow!(
296                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
297            )));
298        }
299        // check reachability
300        let mut client = StarrocksSchemaClient::new(
301            self.config.common.host.clone(),
302            self.config.common.mysql_port.clone(),
303            self.config.common.table.clone(),
304            self.config.common.database.clone(),
305            self.config.common.user.clone(),
306            self.config.common.password.clone(),
307        )
308        .await?;
309        let (read_model, pks) = client.get_pk_from_starrocks().await?;
310
311        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
312            return Err(SinkError::Config(anyhow!(
313                "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
314            )));
315        }
316
317        for (index, filed) in self.schema.fields().iter().enumerate() {
318            if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
319                return Err(SinkError::Starrocks(format!(
320                    "Can't find pk {:?} in starrocks",
321                    filed.name
322                )));
323            }
324        }
325
326        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
327
328        self.check_column_name_and_type(starrocks_columns_desc)?;
329        Ok(())
330    }
331
332    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
333        StarrocksConfig::from_btreemap(config.clone())?;
334        Ok(())
335    }
336
337    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
338        let commit_checkpoint_interval =
339            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
340                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
341            );
342
343        let writer = StarrocksSinkWriter::new(
344            self.config.clone(),
345            self.schema.clone(),
346            self.pk_indices.clone(),
347            self.is_append_only,
348            writer_param.executor_id,
349        )?;
350
351        let metrics = SinkWriterMetrics::new(&writer_param);
352
353        Ok(DecoupleCheckpointLogSinkerOf::new(
354            writer,
355            metrics,
356            commit_checkpoint_interval,
357        ))
358    }
359}
360
/// Writer that streams rows to `StarRocks` through the Stream Load transaction API.
///
/// A transaction is begun lazily by the first `write_batch` that has no open transaction.
/// Each barrier finishes the in-flight load request. Checkpoint barriers prepare and commit
/// the transaction.
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    // In-flight load request of the current transaction; taken and finished on each barrier.
    client: Option<StarrocksClient>,
    // Wrapped in `Arc` so the rollback task spawned from `Drop` can own a handle to it.
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    // Part of every transaction label, to keep labels unique across writers.
    executor_id: u64,
    // Label of the currently open transaction, if any.
    curr_txn_label: Option<String>,
}
374
375impl TryFrom<SinkParam> for StarrocksSink {
376    type Error = SinkError;
377
378    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
379        let schema = param.schema();
380        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
381        StarrocksSink::new(param, config, schema)
382    }
383}
384
385impl StarrocksSinkWriter {
386    pub fn new(
387        config: StarrocksConfig,
388        schema: Schema,
389        pk_indices: Vec<usize>,
390        is_append_only: bool,
391        executor_id: u64,
392    ) -> Result<Self> {
393        let mut field_names = schema.names_str();
394        if !is_append_only {
395            field_names.push(STARROCKS_DELETE_SIGN);
396        };
397        // we should quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request due to
398        // a field name being a reserved word. For example, `order`, 'from`, etc.
399        let field_names = field_names
400            .into_iter()
401            .map(|name| format!("`{}`", name))
402            .collect::<Vec<String>>();
403        let field_names_str = field_names
404            .iter()
405            .map(|name| name.as_str())
406            .collect::<Vec<&str>>();
407
408        let header = HeaderBuilder::new()
409            .add_common_header()
410            .set_user_password(config.common.user.clone(), config.common.password.clone())
411            .add_json_format()
412            .set_partial_update(config.partial_update.clone())
413            .set_columns_name(field_names_str)
414            .set_db(config.common.database.clone())
415            .set_table(config.common.table.clone())
416            .build();
417
418        let url = if config.common.use_https {
419            format!("https://{}:{}", config.common.host, config.common.http_port)
420        } else {
421            format!("http://{}:{}", config.common.host, config.common.http_port)
422        };
423        let txn_request_builder =
424            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
425
426        Ok(Self {
427            config,
428            schema: schema.clone(),
429            pk_indices,
430            is_append_only,
431            client: None,
432            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
433            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
434            executor_id,
435            curr_txn_label: None,
436        })
437    }
438
439    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
440        for (op, row) in chunk.rows() {
441            if op != Op::Insert {
442                continue;
443            }
444            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
445            self.client
446                .as_mut()
447                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
448                .write(row_json_string.into())
449                .await?;
450        }
451        Ok(())
452    }
453
454    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
455        for (op, row) in chunk.rows() {
456            match op {
457                Op::Insert => {
458                    let mut row_json_value = self.row_encoder.encode(row)?;
459                    row_json_value.insert(
460                        STARROCKS_DELETE_SIGN.to_owned(),
461                        Value::String("0".to_owned()),
462                    );
463                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
464                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
465                    })?;
466                    self.client
467                        .as_mut()
468                        .ok_or_else(|| {
469                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
470                        })?
471                        .write(row_json_string.into())
472                        .await?;
473                }
474                Op::Delete => {
475                    let mut row_json_value = self.row_encoder.encode(row)?;
476                    row_json_value.insert(
477                        STARROCKS_DELETE_SIGN.to_owned(),
478                        Value::String("1".to_owned()),
479                    );
480                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
481                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
482                    })?;
483                    self.client
484                        .as_mut()
485                        .ok_or_else(|| {
486                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
487                        })?
488                        .write(row_json_string.into())
489                        .await?;
490                }
491                Op::UpdateDelete => {}
492                Op::UpdateInsert => {
493                    let mut row_json_value = self.row_encoder.encode(row)?;
494                    row_json_value.insert(
495                        STARROCKS_DELETE_SIGN.to_owned(),
496                        Value::String("0".to_owned()),
497                    );
498                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
499                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
500                    })?;
501                    self.client
502                        .as_mut()
503                        .ok_or_else(|| {
504                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
505                        })?
506                        .write(row_json_string.into())
507                        .await?;
508                }
509            }
510        }
511        Ok(())
512    }
513
514    /// Generating a new transaction label, should be unique across all `SinkWriters` even under rewinding.
515    #[inline(always)]
516    fn new_txn_label(&self) -> String {
517        format!(
518            "rw-txn-{}-{}",
519            self.executor_id,
520            chrono::Utc::now().timestamp_micros()
521        )
522    }
523
524    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
525        tracing::debug!(?txn_label, "prepare transaction");
526        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
527        if txn_label != txn_label_res {
528            return Err(SinkError::Starrocks(format!(
529                "label {} returned from prepare transaction {} differs from the current one",
530                txn_label, txn_label_res
531            )));
532        }
533        tracing::debug!(?txn_label, "commit transaction");
534        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
535        if txn_label != txn_label_res {
536            return Err(SinkError::Starrocks(format!(
537                "label {} returned from commit transaction {} differs from the current one",
538                txn_label, txn_label_res
539            )));
540        }
541        Ok(())
542    }
543}
544
545impl Drop for StarrocksSinkWriter {
546    fn drop(&mut self) {
547        if let Some(txn_label) = self.curr_txn_label.take() {
548            let txn_client = self.txn_client.clone();
549            tokio::spawn(async move {
550                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
551                    tracing::error!(
552                        "starrocks rollback transaction error: {:?}, txn label: {}",
553                        e.as_report(),
554                        txn_label
555                    );
556                }
557            });
558        }
559    }
560}
561
562#[async_trait]
563impl SinkWriter for StarrocksSinkWriter {
564    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
565        Ok(())
566    }
567
568    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
569        // We check whether start a new transaction in `write_batch`. Therefore, if no data has been written
570        // within the `commit_checkpoint_interval` period, no meta requests will be made. Otherwise if we request
571        // `prepare` against an empty transaction, the `StarRocks` will report a `hasn't send any data yet` error.
572        if self.curr_txn_label.is_none() {
573            let txn_label = self.new_txn_label();
574            tracing::debug!(?txn_label, "begin transaction");
575            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
576            if txn_label != txn_label_res {
577                return Err(SinkError::Starrocks(format!(
578                    "label {} returned from StarRocks {} differs from generated one",
579                    txn_label, txn_label_res
580                )));
581            }
582            self.curr_txn_label = Some(txn_label.clone());
583        }
584        if self.client.is_none() {
585            let txn_label = self.curr_txn_label.clone();
586            self.client = Some(StarrocksClient::new(
587                self.txn_client.load(txn_label.unwrap()).await?,
588            ));
589        }
590        if self.is_append_only {
591            self.append_only(chunk).await
592        } else {
593            self.upsert(chunk).await
594        }
595    }
596
597    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
598        if let Some(client) = self.client.take() {
599            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
600            // one or more load requests should be made within one commit_checkpoint_interval period.
601            // StarRocks will take care of merging those splits into a larger one during prepare transaction.
602            // Thus, only one version will be produced when the transaction is committed. See Stream Load
603            // transaction interface for more information.
604            client.finish().await?;
605        }
606
607        if is_checkpoint
608            && let Some(txn_label) = self.curr_txn_label.take()
609            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
610        {
611            match self.txn_client.rollback(txn_label.clone()).await {
612                Ok(_) => tracing::warn!(
613                    ?txn_label,
614                    "transaction is successfully rolled back due to commit failure"
615                ),
616                Err(err) => {
617                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
618                }
619            }
620
621            return Err(err);
622        }
623        Ok(())
624    }
625
626    async fn abort(&mut self) -> Result<()> {
627        if let Some(txn_label) = self.curr_txn_label.take() {
628            tracing::debug!(?txn_label, "rollback transaction");
629            self.txn_client.rollback(txn_label).await?;
630        }
631        Ok(())
632    }
633}
634
/// MySQL-protocol client used during sink validation to read table metadata from `StarRocks`
/// (`information_schema.columns` and `information_schema.tables_config`).
pub struct StarrocksSchemaClient {
    // Target table name, used in the metadata queries' `table_name` filter.
    table: String,
    // Target database name, used in the metadata queries' `table_schema` filter.
    db: String,
    conn: mysql_async::Conn,
}
640
641impl StarrocksSchemaClient {
642    pub async fn new(
643        host: String,
644        port: String,
645        table: String,
646        db: String,
647        user: String,
648        password: String,
649    ) -> Result<Self> {
650        // username & password may contain special chars, so we need to do URL encoding on them.
651        // Otherwise, Opts::from_url may report a `Parse error`
652        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
653        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
654
655        let conn_uri = format!(
656            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
657            user,
658            password,
659            host,
660            port,
661            db,
662            STARROCK_MYSQL_PREFER_SOCKET,
663            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
664            STARROCK_MYSQL_WAIT_TIMEOUT
665        );
666        let pool = mysql_async::Pool::new(
667            Opts::from_url(&conn_uri)
668                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
669        );
670        let conn = pool
671            .get_conn()
672            .await
673            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
674
675        Ok(Self { table, db, conn })
676    }
677
678    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
679        let query = format!(
680            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
681            self.table, self.db
682        );
683        let mut query_map: HashMap<String, String> = HashMap::default();
684        self.conn
685            .query_map(query, |(column_name, column_type)| {
686                query_map.insert(column_name, column_type)
687            })
688            .await
689            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
690        Ok(query_map)
691    }
692
693    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
694        let query = format!(
695            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
696            self.table, self.db
697        );
698        let table_mode_pk: (String, String) = self
699            .conn
700            .query_map(
701                query,
702                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
703                    .as_str()
704                {
705                    // Get primary key of aggregate table from the sort_key field
706                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
707                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
708                    "AGG_KEYS" => (table_model, sort_key),
709                    _ => (table_model, primary_key),
710                },
711            )
712            .await
713            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
714            .first()
715            .ok_or_else(|| {
716                SinkError::Starrocks(format!(
717                    "Can't find schema with table {:?} and database {:?}",
718                    self.table, self.db
719                ))
720            })?
721            .clone();
722        Ok(table_mode_pk)
723    }
724}
725
/// JSON body returned by `StarRocks` for stream-load and transaction requests.
///
/// Parsed by `StarrocksClient::finish` and `StarrocksTxnClient::check_response_and_extract_label`,
/// which treat any `status` outside `STARROCKS_SUCCESS_STATUS` as a failure.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    // Transaction label echoed back by the server; callers compare it with the label they sent.
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    // Included in the error message when a request fails.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
765
766pub struct StarrocksClient {
767    insert: InserterInner,
768}
769impl StarrocksClient {
770    pub fn new(insert: InserterInner) -> Self {
771        Self { insert }
772    }
773
774    pub async fn write(&mut self, data: Bytes) -> Result<()> {
775        self.insert.write(data).await?;
776        Ok(())
777    }
778
779    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
780        let raw = self.insert.finish().await?;
781        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
782            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
783
784        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
785            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
786                "Insert error: {}, {}, {:?}",
787                res.status,
788                res.message,
789                res.error_url,
790            )));
791        };
792        Ok(res)
793    }
794}
795
796pub struct StarrocksTxnClient {
797    request_builder: StarrocksTxnRequestBuilder,
798}
799
800impl StarrocksTxnClient {
801    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
802        Self { request_builder }
803    }
804
805    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
806        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
807            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
808        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
809            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
810                "transaction error: {}, {}, {:?}",
811                res.status,
812                res.message,
813                res.error_url,
814            )));
815        }
816        res.label.ok_or_else(|| {
817            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
818        })
819    }
820
821    pub async fn begin(&self, label: String) -> Result<String> {
822        let res = self
823            .request_builder
824            .build_begin_request_sender(label)?
825            .send()
826            .await?;
827        self.check_response_and_extract_label(res)
828    }
829
830    pub async fn prepare(&self, label: String) -> Result<String> {
831        let res = self
832            .request_builder
833            .build_prepare_request_sender(label)?
834            .send()
835            .await?;
836        self.check_response_and_extract_label(res)
837    }
838
839    pub async fn commit(&self, label: String) -> Result<String> {
840        let res = self
841            .request_builder
842            .build_commit_request_sender(label)?
843            .send()
844            .await?;
845        self.check_response_and_extract_label(res)
846    }
847
848    pub async fn rollback(&self, label: String) -> Result<String> {
849        let res = self
850            .request_builder
851            .build_rollback_request_sender(label)?
852            .send()
853            .await?;
854        self.check_response_and_extract_label(res)
855    }
856
857    pub async fn load(&self, label: String) -> Result<InserterInner> {
858        self.request_builder.build_txn_inserter(label).await
859    }
860}