risingwave_connector/sink/starrocks.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use mysql_async::Opts;
use mysql_async::prelude::Queryable;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use url::form_urlencoded;
use with_options::WithOptions;

use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
use super::doris_starrocks_connector::{
    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
    StarrocksTxnRequestBuilder,
};
use super::encoder::{JsonEncoder, RowEncoder};
use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
    SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
use crate::sink::writer::SinkWriter;
use crate::sink::{Result, Sink, SinkWriterParam};

pub const STARROCKS_SINK: &str = "starrocks";
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;

pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    30 * 1000
}

const fn default_use_https() -> bool {
    false
}

#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located.
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use https to connect to the `StarRocks` server.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}

impl EnforceSecret for StarrocksCommon {
    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
        "starrocks.password", "starrocks.user"
    };
}

#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct StarrocksConfig {
    #[serde(flatten)]
    pub common: StarrocksCommon,

    /// The timeout in milliseconds for the Stream Load HTTP request; defaults to 30 seconds.
    #[serde(
        rename = "starrocks.stream_load.http.timeout.ms",
        default = "_default_stream_load_http_timeout_ms"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub stream_load_http_timeout_ms: u64,

    /// Set this option to a positive integer n so that RisingWave commits data to
    /// `StarRocks` every n checkpoints by leveraging the
    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/).
    /// When doing so, the `sink_decouple` option should be enabled as well.
    /// Defaults to 10 if not specified; a value of 0 is rejected.
    #[serde(default = "default_commit_checkpoint_interval")]
    #[serde_as(as = "DisplayFromStr")]
    #[with_option(allow_alter_on_fly)]
    pub commit_checkpoint_interval: u64,

    /// Enable partial update.
    #[serde(rename = "starrocks.partial_update")]
    pub partial_update: Option<String>,

    pub r#type: String, // accept "append-only" or "upsert"
}

impl EnforceSecret for StarrocksConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        StarrocksCommon::enforce_one(prop)
    }
}

fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}

impl StarrocksConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        if config.commit_checkpoint_interval == 0 {
            return Err(SinkError::Config(anyhow!(
                "`commit_checkpoint_interval` must be greater than 0"
            )));
        }
        Ok(config)
    }
}

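// For reference, an illustrative (placeholder) property map accepted by
// `StarrocksConfig::from_btreemap`; the keys follow the `#[serde(rename = ...)]` attributes
// above, and the remaining options fall back to their defaults when omitted:
//
//     starrocks.host      = "fe-host"
//     starrocks.mysqlport = "9030"
//     starrocks.httpport  = "8030"
//     starrocks.user      = "root"
//     starrocks.password  = "******"
//     starrocks.database  = "demo"
//     starrocks.table     = "demo_table"
//     type                = "append-only"   // or "upsert"
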
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for StarrocksSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            StarrocksConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl StarrocksSink {
    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
        let pk_indices = param.downstream_pk_or_empty();
        let is_append_only = param.sink_type.is_append_only();
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl StarrocksSink {
    fn check_column_name_and_type(
        &self,
        starrocks_columns_desc: HashMap<String, String>,
    ) -> Result<()> {
        let rw_fields_name = self.schema.fields();
        if rw_fields_name.len() > starrocks_columns_desc.len() {
            return Err(SinkError::Starrocks(
                "The columns of the sink must be a subset of the target table's columns."
                    .to_owned(),
            ));
        }

        for i in rw_fields_name {
            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Column {:?} was not found in the StarRocks table",
                    i.name
                ))
            })?;
            if !Self::check_and_correct_column_type(&i.data_type, value)? {
                return Err(SinkError::Starrocks(format!(
                    "Column type mismatch for column {:?}: StarRocks type is {:?}, RisingWave type is {:?}",
                    i.name, value, i.data_type
                )));
            }
        }
        Ok(())
    }

    fn check_and_correct_column_type(
        rw_data_type: &DataType,
        starrocks_data_type: &String,
    ) -> Result<bool> {
        match rw_data_type {
            risingwave_common::types::DataType::Boolean => {
                Ok(starrocks_data_type.contains("tinyint") || starrocks_data_type.contains("boolean"))
            }
            risingwave_common::types::DataType::Int16 => {
                Ok(starrocks_data_type.contains("smallint"))
            }
            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
            risingwave_common::types::DataType::Float32 => {
                Ok(starrocks_data_type.contains("float"))
            }
            risingwave_common::types::DataType::Float64 => {
                Ok(starrocks_data_type.contains("double"))
            }
            risingwave_common::types::DataType::Decimal => {
                Ok(starrocks_data_type.contains("decimal"))
            }
            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
            risingwave_common::types::DataType::Varchar => {
                Ok(starrocks_data_type.contains("varchar"))
            }
            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
                "TIME is not supported for StarRocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            risingwave_common::types::DataType::Timestamp => {
                Ok(starrocks_data_type.contains("datetime"))
            }
            risingwave_common::types::DataType::Timestamptz => Err(SinkError::Starrocks(
                "TIMESTAMP WITH TIMEZONE is not supported for StarRocks sink as StarRocks doesn't store time values with timezone information. Please convert to TIMESTAMP first.".to_owned(),
            )),
            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
                "INTERVAL is not supported for StarRocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
                "STRUCT is not supported for StarRocks sink.".to_owned(),
            )),
            risingwave_common::types::DataType::List(list) => {
                // For compatibility with older versions of StarRocks.
                if starrocks_data_type.contains("unknown") {
                    return Ok(true);
                }
                let check_result = Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
                Ok(check_result && starrocks_data_type.contains("array"))
            }
            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
                "BYTEA is not supported for StarRocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
            )),
            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
            risingwave_common::types::DataType::Serial => {
                Ok(starrocks_data_type.contains("bigint"))
            }
            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
                "INT256 is not supported for StarRocks sink.".to_owned(),
            )),
            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
                "MAP is not supported for StarRocks sink.".to_owned(),
            )),
            DataType::Vector(_) => Err(SinkError::Starrocks(
                "VECTOR is not supported for StarRocks sink.".to_owned(),
            )),
        }
    }
}

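// NOTE: the type checks above are substring-based and therefore permissive. For example, a
// RisingWave `INT32` column passes against a StarRocks `bigint` column because the string
// "bigint" contains "int". They are meant to catch obvious mismatches rather than to enforce
// an exact type mapping.
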
impl Sink for StarrocksSink {
    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;

    const SINK_NAME: &'static str = STARROCKS_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
            )));
        }
        // check reachability
        let mut client = StarrocksSchemaClient::new(
            self.config.common.host.clone(),
            self.config.common.mysql_port.clone(),
            self.config.common.table.clone(),
            self.config.common.database.clone(),
            self.config.common.user.clone(),
            self.config.common.password.clone(),
        )
        .await?;
        let (read_model, pks) = client.get_pk_from_starrocks().await?;

        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
            return Err(SinkError::Config(anyhow!(
                "If you want to use upsert, please set the keys type of the StarRocks table to PRIMARY_KEY"
            )));
        }

        for (index, field) in self.schema.fields().iter().enumerate() {
            if self.pk_indices.contains(&index) && !pks.contains(&field.name) {
                return Err(SinkError::Starrocks(format!(
                    "Can't find primary key {:?} in StarRocks",
                    field.name
                )));
            }
        }

        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;

        self.check_column_name_and_type(starrocks_columns_desc)?;
        Ok(())
    }

    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
        StarrocksConfig::from_btreemap(config.clone())?;
        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        let commit_checkpoint_interval =
            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
            );

        let writer = StarrocksSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
            writer_param.executor_id,
        )?;

        let metrics = SinkWriterMetrics::new(&writer_param);

        Ok(DecoupleCheckpointLogSinkerOf::new(
            writer,
            metrics,
            commit_checkpoint_interval,
        ))
    }
}

pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    is_append_only: bool,
    client: Option<StarrocksClient>,
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    executor_id: u64,
    curr_txn_label: Option<String>,
}

impl TryFrom<SinkParam> for StarrocksSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
        StarrocksSink::new(param, config, schema)
    }
}

impl StarrocksSinkWriter {
    pub fn new(
        config: StarrocksConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
        executor_id: u64,
    ) -> Result<Self> {
        let mut field_names = schema.names_str();
        if !is_append_only {
            field_names.push(STARROCKS_DELETE_SIGN);
        };
        // We should quote field names in `MySQL` style to prevent `StarRocks` from rejecting
        // the request due to a field name being a reserved word, e.g. `order`, `from`, etc.
        let field_names = field_names
            .into_iter()
            .map(|name| format!("`{}`", name))
            .collect::<Vec<String>>();
        let field_names_str = field_names
            .iter()
            .map(|name| name.as_str())
            .collect::<Vec<&str>>();

        let header = HeaderBuilder::new()
            .add_common_header()
            .set_user_password(config.common.user.clone(), config.common.password.clone())
            .add_json_format()
            .set_partial_update(config.partial_update.clone())
            .set_columns_name(field_names_str)
            .set_db(config.common.database.clone())
            .set_table(config.common.table.clone())
            .build();

        let url = if config.common.use_https {
            format!("https://{}:{}", config.common.host, config.common.http_port)
        } else {
            format!("http://{}:{}", config.common.host, config.common.http_port)
        };
        let txn_request_builder =
            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;

        Ok(Self {
            config,
            schema: schema.clone(),
            pk_indices,
            is_append_only,
            client: None,
            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
            row_encoder: JsonEncoder::new_with_starrocks(schema, None),
            executor_id,
            curr_txn_label: None,
        })
    }

    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                continue;
            }
            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
            self.client
                .as_mut()
                .ok_or_else(|| SinkError::Starrocks("StarRocks sink client is not initialized".to_owned()))?
                .write(row_json_string.into())
                .await?;
        }
        Ok(())
    }

    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks("StarRocks sink client is not initialized".to_owned())
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::Delete => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("1".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks("StarRocks sink client is not initialized".to_owned())
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
                Op::UpdateDelete => {}
                Op::UpdateInsert => {
                    let mut row_json_value = self.row_encoder.encode(row)?;
                    row_json_value.insert(
                        STARROCKS_DELETE_SIGN.to_owned(),
                        Value::String("0".to_owned()),
                    );
                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
                        SinkError::Starrocks(format!("JSON serialize error: {}", e.as_report()))
                    })?;
                    self.client
                        .as_mut()
                        .ok_or_else(|| {
                            SinkError::Starrocks("StarRocks sink client is not initialized".to_owned())
                        })?
                        .write(row_json_string.into())
                        .await?;
                }
            }
        }
        Ok(())
    }

    /// Generate a new transaction label; it should be unique across all `SinkWriter`s, even
    /// under rewinding.
    #[inline(always)]
    fn new_txn_label(&self) -> String {
        format!(
            "rw-txn-{}-{}",
            self.executor_id,
            chrono::Utc::now().timestamp_micros()
        )
    }

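    // Labels generated above look like `rw-txn-<executor_id>-<timestamp_micros>`,
    // e.g. `rw-txn-42-1718000000000000` (the executor id and timestamp are illustrative).
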
    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
        tracing::debug!(?txn_label, "prepare transaction");
        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "current label {} differs from label {} returned by the prepare transaction",
                txn_label, txn_label_res
            )));
        }
        tracing::debug!(?txn_label, "commit transaction");
        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
        if txn_label != txn_label_res {
            return Err(SinkError::Starrocks(format!(
                "current label {} differs from label {} returned by the commit transaction",
                txn_label, txn_label_res
            )));
        }
        Ok(())
    }
}

impl Drop for StarrocksSinkWriter {
    fn drop(&mut self) {
        if let Some(txn_label) = self.curr_txn_label.take() {
            let txn_client = self.txn_client.clone();
            tokio::spawn(async move {
                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
                    tracing::error!(
                        "starrocks rollback transaction error: {:?}, txn label: {}",
                        e.as_report(),
                        txn_label
                    );
                }
            });
        }
    }
}

#[async_trait]
impl SinkWriter for StarrocksSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        // We check whether to start a new transaction in `write_batch`. Therefore, if no data
        // has been written within the `commit_checkpoint_interval` period, no meta requests
        // will be made. Otherwise, if we issued `prepare` against an empty transaction,
        // `StarRocks` would report a `hasn't send any data yet` error.
        if self.curr_txn_label.is_none() {
            let txn_label = self.new_txn_label();
            tracing::debug!(?txn_label, "begin transaction");
            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
            if txn_label != txn_label_res {
                return Err(SinkError::Starrocks(format!(
                    "generated label {} differs from label {} returned by StarRocks",
                    txn_label, txn_label_res
                )));
            }
            self.curr_txn_label = Some(txn_label.clone());
        }
        if self.client.is_none() {
            let txn_label = self.curr_txn_label.clone();
            self.client = Some(StarrocksClient::new(
                self.txn_client.load(txn_label.unwrap()).await?,
            ));
        }
        if self.is_append_only {
            self.append_only(chunk).await
        } else {
            self.upsert(chunk).await
        }
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
        if let Some(client) = self.client.take() {
            // Here we finish the `/api/transaction/load` request when a barrier is received.
            // Therefore, one or more load requests may be made within one
            // `commit_checkpoint_interval` period. StarRocks will take care of merging those
            // splits into a larger one during the prepare phase of the transaction, so only
            // one version is produced when the transaction is committed. See the Stream Load
            // transaction interface for more information.
            client.finish().await?;
        }

        if is_checkpoint
            && let Some(txn_label) = self.curr_txn_label.take()
            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
        {
            match self.txn_client.rollback(txn_label.clone()).await {
                Ok(_) => tracing::warn!(
                    ?txn_label,
                    "transaction is successfully rolled back due to commit failure"
                ),
                Err(err) => {
                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
                }
            }

            return Err(err);
        }
        Ok(())
    }

    async fn abort(&mut self) -> Result<()> {
        if let Some(txn_label) = self.curr_txn_label.take() {
            tracing::debug!(?txn_label, "rollback transaction");
            self.txn_client.rollback(txn_label).await?;
        }
        Ok(())
    }
}

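// Transaction lifecycle of `StarrocksSinkWriter`, as implemented by `write_batch`, `barrier`,
// `abort`, and `Drop` above:
// 1. The first `write_batch` after a checkpoint lazily begins a new Stream Load transaction
//    and opens an `/api/transaction/load` request, so an interval without data never touches
//    StarRocks.
// 2. Each `barrier` finishes the in-flight load request; on a checkpoint barrier the
//    transaction is prepared and committed, and rolled back if the commit fails.
// 3. `abort` and `Drop` roll back any transaction that is still open.
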
pub struct StarrocksSchemaClient {
    table: String,
    db: String,
    conn: mysql_async::Conn,
}

impl StarrocksSchemaClient {
    pub async fn new(
        host: String,
        port: String,
        table: String,
        db: String,
        user: String,
        password: String,
    ) -> Result<Self> {
        // The username and password may contain special characters, so we need to URL-encode
        // them; otherwise, `Opts::from_url` may report a `Parse error`.
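        // For example (illustrative), a password like `p@ss:word` is encoded as `p%40ss%3Aword`.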
        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();

        let conn_uri = format!(
            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
            user,
            password,
            host,
            port,
            db,
            STARROCK_MYSQL_PREFER_SOCKET,
            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
            STARROCK_MYSQL_WAIT_TIMEOUT
        );
        let pool = mysql_async::Pool::new(
            Opts::from_url(&conn_uri)
                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
        );
        let conn = pool
            .get_conn()
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        Ok(Self { table, db, conn })
    }

    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
        let query = format!(
            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let mut query_map: HashMap<String, String> = HashMap::default();
        self.conn
            .query_map(query, |(column_name, column_type)| {
                query_map.insert(column_name, column_type)
            })
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        Ok(query_map)
    }

    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
        let query = format!(
            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
            self.table, self.db
        );
        let table_mode_pk: (String, String) = self
            .conn
            .query_map(
                query,
                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
                    .as_str()
                {
                    // Get primary key of aggregate table from the sort_key field
                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
                    "AGG_KEYS" => (table_model, sort_key),
                    _ => (table_model, primary_key),
                },
            )
            .await
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
            .first()
            .ok_or_else(|| {
                SinkError::Starrocks(format!(
                    "Can't find schema with table {:?} and database {:?}",
                    self.table, self.db
                ))
            })?
            .clone();
        Ok(table_mode_pk)
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    #[serde(rename = "Label")]
    pub label: Option<String>,
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}

pub struct StarrocksClient {
    insert: InserterInner,
}
impl StarrocksClient {
    pub fn new(insert: InserterInner) -> Self {
        Self { insert }
    }

    pub async fn write(&mut self, data: Bytes) -> Result<()> {
        self.insert.write(data).await?;
        Ok(())
    }

    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
        let raw = self.insert.finish().await?;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "Insert error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        };
        Ok(res)
    }
}

pub struct StarrocksTxnClient {
    request_builder: StarrocksTxnRequestBuilder,
}

impl StarrocksTxnClient {
    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
        Self { request_builder }
    }

    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
                "transaction error: {}, {}, {:?}",
                res.status,
                res.message,
                res.error_url,
            )));
        }
        res.label.ok_or_else(|| {
            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
        })
    }

    pub async fn begin(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_begin_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn prepare(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_prepare_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn commit(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_commit_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn rollback(&self, label: String) -> Result<String> {
        let res = self
            .request_builder
            .build_rollback_request_sender(label)?
            .send()
            .await?;
        self.check_response_and_extract_label(res)
    }

    pub async fn load(&self, label: String) -> Result<InserterInner> {
        self.request_builder.build_txn_inserter(label).await
    }
}