risingwave_connector/sink/
starrocks.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::num::NonZeroU64;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use async_trait::async_trait;
21use bytes::Bytes;
22use chrono_tz::Tz;
23use mysql_async::Opts;
24use mysql_async::prelude::Queryable;
25use risingwave_common::array::{Op, StreamChunk};
26use risingwave_common::catalog::Schema;
27use risingwave_common::types::DataType;
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30use serde_with::{DisplayFromStr, serde_as};
31use thiserror_ext::AsReport;
32use url::form_urlencoded;
33use uuid::Uuid;
34use with_options::WithOptions;
35
36use super::decouple_checkpoint_log_sink::DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE;
37use super::doris_starrocks_connector::{
38    HeaderBuilder, InserterInner, STARROCKS_DELETE_SIGN, STARROCKS_SUCCESS_STATUS,
39    StarrocksTxnRequestBuilder,
40};
41use super::encoder::{JsonEncoder, RowEncoder};
42use super::{
43    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkParam,
44    SinkWriterMetrics,
45};
46use crate::enforce_secret::EnforceSecret;
47use crate::sink::decouple_checkpoint_log_sink::DecoupleCheckpointLogSinkerOf;
48use crate::sink::writer::SinkWriter;
49use crate::sink::{Result, Sink, SinkWriterParam};
50
/// Connector name under which this sink is registered.
pub const STARROCKS_SINK: &str = "starrocks";

// Query-string parameters appended to the MySQL URL built in `StarrocksSchemaClient::new`.
// NOTE(review): units are presumably bytes (max_allowed_packet) and seconds (wait_timeout),
// following MySQL conventions — confirm against the mysql_async `Opts` URL parser.
const STARROCK_MYSQL_PREFER_SOCKET: &str = "false";
const STARROCK_MYSQL_MAX_ALLOWED_PACKET: usize = 1024;
const STARROCK_MYSQL_WAIT_TIMEOUT: usize = 28800;

/// Default timeout of a stream-load HTTP request, in milliseconds (30 seconds).
pub const fn _default_stream_load_http_timeout_ms() -> u64 {
    const THIRTY_SECONDS_IN_MS: u64 = 30_000;
    THIRTY_SECONDS_IN_MS
}

/// Default for `starrocks.use_https`: stream load goes over plain HTTP unless enabled.
const fn default_use_https() -> bool {
    false
}
63
/// Connection and target-table properties shared by StarRocks sink configurations.
///
/// Deserialized from the sink's `WITH` properties; it is `#[serde(flatten)]`-ed into
/// `StarrocksConfig`.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct StarrocksCommon {
    /// The `StarRocks` host address.
    #[serde(rename = "starrocks.host")]
    pub host: String,
    /// The port to the MySQL server of `StarRocks` FE.
    #[serde(rename = "starrocks.mysqlport", alias = "starrocks.query_port")]
    pub mysql_port: String,
    /// The port to the HTTP server of `StarRocks` FE.
    #[serde(rename = "starrocks.httpport", alias = "starrocks.http_port")]
    pub http_port: String,
    /// The user name used to access the `StarRocks` database.
    #[serde(rename = "starrocks.user")]
    pub user: String,
    /// The password associated with the user.
    #[serde(rename = "starrocks.password")]
    pub password: String,
    /// The `StarRocks` database where the target table is located
    #[serde(rename = "starrocks.database")]
    pub database: String,
    /// The `StarRocks` table you want to sink data to.
    #[serde(rename = "starrocks.table")]
    pub table: String,

    /// Whether to use https to connect to the `StarRocks` server.
    #[serde(rename = "starrocks.use_https")]
    #[serde(default = "default_use_https")]
    #[serde_as(as = "DisplayFromStr")]
    pub use_https: bool,
}
95
96impl EnforceSecret for StarrocksCommon {
97    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf::phf_set! {
98        "starrocks.password", "starrocks.user"
99    };
100}
101
102#[serde_as]
103#[derive(Clone, Debug, Deserialize, WithOptions)]
104pub struct StarrocksConfig {
105    #[serde(flatten)]
106    pub common: StarrocksCommon,
107
108    /// The timeout in milliseconds for stream load http request, defaults to 10 seconds.
109    #[serde(
110        rename = "starrocks.stream_load.http.timeout.ms",
111        default = "_default_stream_load_http_timeout_ms"
112    )]
113    #[serde_as(as = "DisplayFromStr")]
114    #[with_option(allow_alter_on_fly)]
115    pub stream_load_http_timeout_ms: u64,
116
117    /// Set this option to a positive integer n, RisingWave will try to commit data
118    /// to Starrocks at every n checkpoints by leveraging the
119    /// [StreamLoad Transaction API](https://docs.starrocks.io/docs/loading/Stream_Load_transaction_interface/),
120    /// also, in this time, the `sink_decouple` option should be enabled as well.
121    /// Defaults to 10 if `commit_checkpoint_interval` <= 0
122    #[serde(default = "default_commit_checkpoint_interval")]
123    #[serde_as(as = "DisplayFromStr")]
124    #[with_option(allow_alter_on_fly)]
125    pub commit_checkpoint_interval: u64,
126
127    /// Enable partial update
128    #[serde(rename = "starrocks.partial_update")]
129    pub partial_update: Option<String>,
130
131    pub r#type: String, // accept "append-only" or "upsert"
132}
133
134impl EnforceSecret for StarrocksConfig {
135    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
136        StarrocksCommon::enforce_one(prop)
137    }
138}
139
/// Serde default for `commit_checkpoint_interval` when the property is omitted.
fn default_commit_checkpoint_interval() -> u64 {
    DEFAULT_COMMIT_CHECKPOINT_INTERVAL_WITH_SINK_DECOUPLE
}
143
144impl StarrocksConfig {
145    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
146        let config =
147            serde_json::from_value::<StarrocksConfig>(serde_json::to_value(properties).unwrap())
148                .map_err(|e| SinkError::Config(anyhow!(e)))?;
149        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
150            return Err(SinkError::Config(anyhow!(
151                "`{}` must be {}, or {}",
152                SINK_TYPE_OPTION,
153                SINK_TYPE_APPEND_ONLY,
154                SINK_TYPE_UPSERT
155            )));
156        }
157        if config.commit_checkpoint_interval == 0 {
158            return Err(SinkError::Config(anyhow!(
159                "`commit_checkpoint_interval` must be greater than 0"
160            )));
161        }
162        Ok(config)
163    }
164}
165
/// The StarRocks sink: validated configuration plus the sink's schema.
#[derive(Debug)]
pub struct StarrocksSink {
    pub config: StarrocksConfig,
    schema: Schema,
    // Downstream primary-key column indices into `schema` (empty if none were defined).
    pk_indices: Vec<usize>,
    // Whether the sink type is append-only (as opposed to upsert).
    is_append_only: bool,
}
173
174impl EnforceSecret for StarrocksSink {
175    fn enforce_secret<'a>(
176        prop_iter: impl Iterator<Item = &'a str>,
177    ) -> crate::error::ConnectorResult<()> {
178        for prop in prop_iter {
179            StarrocksConfig::enforce_one(prop)?;
180        }
181        Ok(())
182    }
183}
184
185impl StarrocksSink {
186    pub fn new(param: SinkParam, config: StarrocksConfig, schema: Schema) -> Result<Self> {
187        let pk_indices = param.downstream_pk_or_empty();
188        let is_append_only = param.sink_type.is_append_only();
189        Ok(Self {
190            config,
191            schema,
192            pk_indices,
193            is_append_only,
194        })
195    }
196}
197
198impl StarrocksSink {
199    fn starrocks_data_type_contains_any(
200        starrocks_data_type: &str,
201        expected_types: &[&str],
202    ) -> bool {
203        expected_types
204            .iter()
205            .any(|expected_type| starrocks_data_type.contains(expected_type))
206    }
207
208    fn check_column_name_and_type(
209        &self,
210        starrocks_columns_desc: HashMap<String, String>,
211    ) -> Result<()> {
212        let rw_fields_name = self.schema.fields();
213        if rw_fields_name.len() > starrocks_columns_desc.len() {
214            return Err(SinkError::Starrocks("The columns of the sink must be equal to or a superset of the target table's columns.".to_owned()));
215        }
216
217        for i in rw_fields_name {
218            let value = starrocks_columns_desc.get(&i.name).ok_or_else(|| {
219                SinkError::Starrocks(format!(
220                    "Column name don't find in starrocks, risingwave is {:?} ",
221                    i.name
222                ))
223            })?;
224            if !Self::check_and_correct_column_type(&i.data_type, value)? {
225                return Err(SinkError::Starrocks(format!(
226                    "Column type don't match, column name is {:?}. starrocks type is {:?} risingwave type is {:?} ",
227                    i.name, value, i.data_type
228                )));
229            }
230        }
231        Ok(())
232    }
233
234    fn check_and_correct_column_type(
235        rw_data_type: &DataType,
236        starrocks_data_type: &str,
237    ) -> Result<bool> {
238        match rw_data_type {
239            risingwave_common::types::DataType::Boolean => {
240                Ok(Self::starrocks_data_type_contains_any(
241                    starrocks_data_type,
242                    &["tinyint", "boolean"],
243                ))
244            }
245            risingwave_common::types::DataType::Int16 => {
246                Ok(starrocks_data_type.contains("smallint"))
247            }
248            risingwave_common::types::DataType::Int32 => Ok(starrocks_data_type.contains("int")),
249            risingwave_common::types::DataType::Int64 => Ok(starrocks_data_type.contains("bigint")),
250            risingwave_common::types::DataType::Float32 => {
251                Ok(starrocks_data_type.contains("float"))
252            }
253            risingwave_common::types::DataType::Float64 => {
254                Ok(starrocks_data_type.contains("double"))
255            }
256            risingwave_common::types::DataType::Decimal => {
257                Ok(starrocks_data_type.contains("decimal"))
258            }
259            risingwave_common::types::DataType::Date => Ok(starrocks_data_type.contains("date")),
260            risingwave_common::types::DataType::Varchar => {
261                Ok(starrocks_data_type.contains("varchar"))
262            }
263            risingwave_common::types::DataType::Time => Err(SinkError::Starrocks(
264                "TIME is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
265            )),
266            risingwave_common::types::DataType::Timestamp => {
267                Ok(starrocks_data_type.contains("datetime"))
268            }
269            risingwave_common::types::DataType::Timestamptz => {
270                // StarRocks JSON encoding writes timestamptz as a timestamp string without a
271                // timezone suffix, so StarRocks can parse it into DATETIME or store it as text.
272                Ok(Self::starrocks_data_type_contains_any(
273                    starrocks_data_type,
274                    &["datetime", "varchar", "char", "string"],
275                ))
276            }
277            risingwave_common::types::DataType::Interval => Err(SinkError::Starrocks(
278                "INTERVAL is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
279            )),
280            risingwave_common::types::DataType::Struct(_) => Err(SinkError::Starrocks(
281                "STRUCT is not supported for Starrocks sink.".to_owned(),
282            )),
283            risingwave_common::types::DataType::List(list) => {
284                // For compatibility with older versions starrocks
285                if starrocks_data_type.contains("unknown") {
286                    return Ok(true);
287                }
288                let check_result = Self::check_and_correct_column_type(list.elem(), starrocks_data_type)?;
289                Ok(check_result && starrocks_data_type.contains("array"))
290            }
291            risingwave_common::types::DataType::Bytea => Err(SinkError::Starrocks(
292                "BYTEA is not supported for Starrocks sink. Please convert to VARCHAR or other supported types.".to_owned(),
293            )),
294            risingwave_common::types::DataType::Jsonb => Ok(starrocks_data_type.contains("json")),
295            risingwave_common::types::DataType::Serial => {
296                Ok(starrocks_data_type.contains("bigint"))
297            }
298            risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks(
299                "INT256 is not supported for Starrocks sink.".to_owned(),
300            )),
301            risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks(
302                "MAP is not supported for Starrocks sink.".to_owned(),
303            )),
304            DataType::Vector(_) => Err(SinkError::Starrocks(
305                "VECTOR is not supported for Starrocks sink.".to_owned(),
306            )),
307        }
308    }
309}
310
311impl Sink for StarrocksSink {
312    type LogSinker = DecoupleCheckpointLogSinkerOf<StarrocksSinkWriter>;
313
314    const SINK_NAME: &'static str = STARROCKS_SINK;
315
316    async fn validate(&self) -> Result<()> {
317        if !self.is_append_only && self.pk_indices.is_empty() {
318            return Err(SinkError::Config(anyhow!(
319                "Primary key not defined for upsert starrocks sink (please define in `primary_key` field)"
320            )));
321        }
322        // check reachability
323        let mut client = StarrocksSchemaClient::new(
324            self.config.common.host.clone(),
325            self.config.common.mysql_port.clone(),
326            self.config.common.table.clone(),
327            self.config.common.database.clone(),
328            self.config.common.user.clone(),
329            self.config.common.password.clone(),
330        )
331        .await?;
332        let (read_model, pks) = client.get_pk_from_starrocks().await?;
333
334        if !self.is_append_only && read_model.ne("PRIMARY_KEYS") {
335            return Err(SinkError::Config(anyhow!(
336                "If you want to use upsert, please set the keysType of starrocks to PRIMARY_KEY"
337            )));
338        }
339
340        for (index, filed) in self.schema.fields().iter().enumerate() {
341            if self.pk_indices.contains(&index) && !pks.contains(&filed.name) {
342                return Err(SinkError::Starrocks(format!(
343                    "Can't find pk {:?} in starrocks",
344                    filed.name
345                )));
346            }
347        }
348
349        let starrocks_columns_desc = client.get_columns_from_starrocks().await?;
350
351        self.check_column_name_and_type(starrocks_columns_desc)?;
352        Ok(())
353    }
354
355    fn validate_alter_config(config: &BTreeMap<String, String>) -> Result<()> {
356        StarrocksConfig::from_btreemap(config.clone())?;
357        Ok(())
358    }
359
360    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
361        let commit_checkpoint_interval =
362            NonZeroU64::new(self.config.commit_checkpoint_interval).expect(
363                "commit_checkpoint_interval should be greater than 0, and it should be checked in config validation",
364            );
365
366        let writer = StarrocksSinkWriter::new(
367            self.config.clone(),
368            self.schema.clone(),
369            self.pk_indices.clone(),
370            self.is_append_only,
371            writer_param.time_zone,
372        )?;
373
374        let metrics = SinkWriterMetrics::new(&writer_param);
375
376        Ok(DecoupleCheckpointLogSinkerOf::new(
377            writer,
378            metrics,
379            commit_checkpoint_interval,
380        ))
381    }
382}
383
/// Writes stream chunks to StarRocks through the Stream Load transaction interface.
pub struct StarrocksSinkWriter {
    pub config: StarrocksConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    // Append-only sinks skip non-insert ops; otherwise rows carry the StarRocks delete sign.
    is_append_only: bool,
    // In-flight load request; opened lazily in `write_batch` and finished on `barrier`.
    client: Option<StarrocksClient>,
    // Issues begin/load/prepare/commit/rollback requests; shared via `Arc` so `Drop` can
    // move a handle into a spawned rollback task.
    txn_client: Arc<StarrocksTxnClient>,
    row_encoder: JsonEncoder,
    // Label of the currently open transaction, if any.
    curr_txn_label: Option<String>,
}
396
397impl TryFrom<SinkParam> for StarrocksSink {
398    type Error = SinkError;
399
400    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
401        let schema = param.schema();
402        let config = StarrocksConfig::from_btreemap(param.properties.clone())?;
403        StarrocksSink::new(param, config, schema)
404    }
405}
406
407impl StarrocksSinkWriter {
408    pub fn new(
409        config: StarrocksConfig,
410        schema: Schema,
411        pk_indices: Vec<usize>,
412        is_append_only: bool,
413        time_zone: Tz,
414    ) -> Result<Self> {
415        let mut field_names = schema.names_str();
416        if !is_append_only {
417            field_names.push(STARROCKS_DELETE_SIGN);
418        };
419        // we should quote field names in `MySQL` style to prevent `StarRocks` from rejecting the request due to
420        // a field name being a reserved word. For example, `order`, 'from`, etc.
421        let field_names = field_names
422            .into_iter()
423            .map(|name| format!("`{}`", name))
424            .collect::<Vec<String>>();
425        let field_names_str = field_names
426            .iter()
427            .map(|name| name.as_str())
428            .collect::<Vec<&str>>();
429
430        let header = HeaderBuilder::new()
431            .add_common_header()
432            .set_user_password(config.common.user.clone(), config.common.password.clone())
433            .add_json_format()
434            .set_partial_update(config.partial_update.clone())
435            .set_columns_name(field_names_str)
436            .set_db(config.common.database.clone())
437            .set_table(config.common.table.clone())
438            .build();
439
440        let url = if config.common.use_https {
441            format!("https://{}:{}", config.common.host, config.common.http_port)
442        } else {
443            format!("http://{}:{}", config.common.host, config.common.http_port)
444        };
445        let txn_request_builder =
446            StarrocksTxnRequestBuilder::new(url, header, config.stream_load_http_timeout_ms)?;
447
448        Ok(Self {
449            config,
450            schema: schema.clone(),
451            pk_indices,
452            is_append_only,
453            client: None,
454            txn_client: Arc::new(StarrocksTxnClient::new(txn_request_builder)),
455            row_encoder: JsonEncoder::new_with_starrocks(schema, None, time_zone),
456            curr_txn_label: None,
457        })
458    }
459
460    async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
461        for (op, row) in chunk.rows() {
462            if op != Op::Insert {
463                continue;
464            }
465            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
466            self.client
467                .as_mut()
468                .ok_or_else(|| SinkError::Starrocks("Can't find starrocks sink insert".to_owned()))?
469                .write(row_json_string.into())
470                .await?;
471        }
472        Ok(())
473    }
474
475    async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> {
476        for (op, row) in chunk.rows() {
477            match op {
478                Op::Insert => {
479                    let mut row_json_value = self.row_encoder.encode(row)?;
480                    row_json_value.insert(
481                        STARROCKS_DELETE_SIGN.to_owned(),
482                        Value::String("0".to_owned()),
483                    );
484                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
485                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
486                    })?;
487                    self.client
488                        .as_mut()
489                        .ok_or_else(|| {
490                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
491                        })?
492                        .write(row_json_string.into())
493                        .await?;
494                }
495                Op::Delete => {
496                    let mut row_json_value = self.row_encoder.encode(row)?;
497                    row_json_value.insert(
498                        STARROCKS_DELETE_SIGN.to_owned(),
499                        Value::String("1".to_owned()),
500                    );
501                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
502                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
503                    })?;
504                    self.client
505                        .as_mut()
506                        .ok_or_else(|| {
507                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
508                        })?
509                        .write(row_json_string.into())
510                        .await?;
511                }
512                Op::UpdateDelete => {}
513                Op::UpdateInsert => {
514                    let mut row_json_value = self.row_encoder.encode(row)?;
515                    row_json_value.insert(
516                        STARROCKS_DELETE_SIGN.to_owned(),
517                        Value::String("0".to_owned()),
518                    );
519                    let row_json_string = serde_json::to_string(&row_json_value).map_err(|e| {
520                        SinkError::Starrocks(format!("Json derialize error: {}", e.as_report()))
521                    })?;
522                    self.client
523                        .as_mut()
524                        .ok_or_else(|| {
525                            SinkError::Starrocks("Can't find starrocks sink insert".to_owned())
526                        })?
527                        .write(row_json_string.into())
528                        .await?;
529                }
530            }
531        }
532        Ok(())
533    }
534
535    /// Generating a new transaction label, should be unique across all `SinkWriters` even under rewinding.
536    #[inline(always)]
537    fn new_txn_label(&self) -> String {
538        format!(
539            "rw-txn-{}-{}",
540            Uuid::new_v4(),
541            chrono::Utc::now().timestamp_micros()
542        )
543    }
544
545    async fn prepare_and_commit(&self, txn_label: String) -> Result<()> {
546        tracing::debug!(?txn_label, "prepare transaction");
547        let txn_label_res = self.txn_client.prepare(txn_label.clone()).await?;
548        if txn_label != txn_label_res {
549            return Err(SinkError::Starrocks(format!(
550                "label {} returned from prepare transaction {} differs from the current one",
551                txn_label, txn_label_res
552            )));
553        }
554        tracing::debug!(?txn_label, "commit transaction");
555        let txn_label_res = self.txn_client.commit(txn_label.clone()).await?;
556        if txn_label != txn_label_res {
557            return Err(SinkError::Starrocks(format!(
558                "label {} returned from commit transaction {} differs from the current one",
559                txn_label, txn_label_res
560            )));
561        }
562        Ok(())
563    }
564}
565
566impl Drop for StarrocksSinkWriter {
567    fn drop(&mut self) {
568        if let Some(txn_label) = self.curr_txn_label.take() {
569            let txn_client = self.txn_client.clone();
570            tokio::spawn(async move {
571                if let Err(e) = txn_client.rollback(txn_label.clone()).await {
572                    tracing::error!(
573                        "starrocks rollback transaction error: {:?}, txn label: {}",
574                        e.as_report(),
575                        txn_label
576                    );
577                }
578            });
579        }
580    }
581}
582
583#[async_trait]
584impl SinkWriter for StarrocksSinkWriter {
585    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
586        Ok(())
587    }
588
589    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
590        // We check whether start a new transaction in `write_batch`. Therefore, if no data has been written
591        // within the `commit_checkpoint_interval` period, no meta requests will be made. Otherwise if we request
592        // `prepare` against an empty transaction, the `StarRocks` will report a `hasn't send any data yet` error.
593        if self.curr_txn_label.is_none() {
594            let txn_label = self.new_txn_label();
595            tracing::debug!(?txn_label, "begin transaction");
596            let txn_label_res = self.txn_client.begin(txn_label.clone()).await?;
597            if txn_label != txn_label_res {
598                return Err(SinkError::Starrocks(format!(
599                    "label {} returned from StarRocks {} differs from generated one",
600                    txn_label, txn_label_res
601                )));
602            }
603            self.curr_txn_label = Some(txn_label.clone());
604        }
605        if self.client.is_none() {
606            let txn_label = self.curr_txn_label.clone();
607            self.client = Some(StarrocksClient::new(
608                self.txn_client.load(txn_label.unwrap()).await?,
609            ));
610        }
611        if self.is_append_only {
612            self.append_only(chunk).await
613        } else {
614            self.upsert(chunk).await
615        }
616    }
617
618    async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> {
619        if let Some(client) = self.client.take() {
620            // Here we finish the `/api/transaction/load` request when a barrier is received. Therefore,
621            // one or more load requests should be made within one commit_checkpoint_interval period.
622            // StarRocks will take care of merging those splits into a larger one during prepare transaction.
623            // Thus, only one version will be produced when the transaction is committed. See Stream Load
624            // transaction interface for more information.
625            client.finish().await?;
626        }
627
628        if is_checkpoint
629            && let Some(txn_label) = self.curr_txn_label.take()
630            && let Err(err) = self.prepare_and_commit(txn_label.clone()).await
631        {
632            match self.txn_client.rollback(txn_label.clone()).await {
633                Ok(_) => tracing::warn!(
634                    ?txn_label,
635                    "transaction is successfully rolled back due to commit failure"
636                ),
637                Err(err) => {
638                    tracing::warn!(?txn_label, error = ?err.as_report(), "Couldn't roll back transaction after commit failed")
639                }
640            }
641
642            return Err(err);
643        }
644        Ok(())
645    }
646
647    async fn abort(&mut self) -> Result<()> {
648        if let Some(txn_label) = self.curr_txn_label.take() {
649            tracing::debug!(?txn_label, "rollback transaction");
650            self.txn_client.rollback(txn_label).await?;
651        }
652        Ok(())
653    }
654}
655
/// MySQL-protocol client used to read table metadata from StarRocks' `information_schema`.
pub struct StarrocksSchemaClient {
    // Target table name, interpolated into the metadata queries.
    table: String,
    // Target database (schema) name, interpolated into the metadata queries.
    db: String,
    conn: mysql_async::Conn,
}
661
662impl StarrocksSchemaClient {
663    pub async fn new(
664        host: String,
665        port: String,
666        table: String,
667        db: String,
668        user: String,
669        password: String,
670    ) -> Result<Self> {
671        // username & password may contain special chars, so we need to do URL encoding on them.
672        // Otherwise, Opts::from_url may report a `Parse error`
673        let user = form_urlencoded::byte_serialize(user.as_bytes()).collect::<String>();
674        let password = form_urlencoded::byte_serialize(password.as_bytes()).collect::<String>();
675
676        let conn_uri = format!(
677            "mysql://{}:{}@{}:{}/{}?prefer_socket={}&max_allowed_packet={}&wait_timeout={}",
678            user,
679            password,
680            host,
681            port,
682            db,
683            STARROCK_MYSQL_PREFER_SOCKET,
684            STARROCK_MYSQL_MAX_ALLOWED_PACKET,
685            STARROCK_MYSQL_WAIT_TIMEOUT
686        );
687        let pool = mysql_async::Pool::new(
688            Opts::from_url(&conn_uri)
689                .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?,
690        );
691        let conn = pool
692            .get_conn()
693            .await
694            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
695
696        Ok(Self { table, db, conn })
697    }
698
699    pub async fn get_columns_from_starrocks(&mut self) -> Result<HashMap<String, String>> {
700        let query = format!(
701            "select column_name, column_type from information_schema.columns where table_name = {:?} and table_schema = {:?};",
702            self.table, self.db
703        );
704        let mut query_map: HashMap<String, String> = HashMap::default();
705        self.conn
706            .query_map(query, |(column_name, column_type)| {
707                query_map.insert(column_name, column_type)
708            })
709            .await
710            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
711        Ok(query_map)
712    }
713
714    pub async fn get_pk_from_starrocks(&mut self) -> Result<(String, String)> {
715        let query = format!(
716            "select table_model, primary_key, sort_key from information_schema.tables_config where table_name = {:?} and table_schema = {:?};",
717            self.table, self.db
718        );
719        let table_mode_pk: (String, String) = self
720            .conn
721            .query_map(
722                query,
723                |(table_model, primary_key, sort_key): (String, String, String)| match table_model
724                    .as_str()
725                {
726                    // Get primary key of aggregate table from the sort_key field
727                    // https://docs.starrocks.io/docs/table_design/table_types/table_capabilities/
728                    // https://docs.starrocks.io/docs/sql-reference/information_schema/tables_config/
729                    "AGG_KEYS" => (table_model, sort_key),
730                    _ => (table_model, primary_key),
731                },
732            )
733            .await
734            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?
735            .first()
736            .ok_or_else(|| {
737                SinkError::Starrocks(format!(
738                    "Can't find schema for StarRocks table {:?} in database {:?}. Please check that the table exists and the StarRocks user has write permission on the target table.",
739                    self.table, self.db
740                ))
741            })?
742            .clone();
743        Ok(table_mode_pk)
744    }
745}
746
/// JSON response body returned by StarRocks for load and transaction requests.
///
/// Field names follow StarRocks' PascalCase keys via `serde(rename)`. Only `Status` and
/// `Message` are required; every other field is optional.
#[derive(Debug, Serialize, Deserialize)]
pub struct StarrocksInsertResultResponse {
    #[serde(rename = "TxnId")]
    pub txn_id: Option<i64>,
    #[serde(rename = "Seq")]
    pub seq: Option<i64>,
    #[serde(rename = "Label")]
    pub label: Option<String>,
    // Compared against `STARROCKS_SUCCESS_STATUS` to decide whether a request succeeded.
    #[serde(rename = "Status")]
    pub status: String,
    #[serde(rename = "Message")]
    pub message: String,
    #[serde(rename = "NumberTotalRows")]
    pub number_total_rows: Option<i64>,
    #[serde(rename = "NumberLoadedRows")]
    pub number_loaded_rows: Option<i64>,
    #[serde(rename = "NumberFilteredRows")]
    pub number_filtered_rows: Option<i32>,
    #[serde(rename = "NumberUnselectedRows")]
    pub number_unselected_rows: Option<i32>,
    #[serde(rename = "LoadBytes")]
    pub load_bytes: Option<i64>,
    #[serde(rename = "LoadTimeMs")]
    pub load_time_ms: Option<i32>,
    #[serde(rename = "BeginTxnTimeMs")]
    pub begin_txn_time_ms: Option<i32>,
    #[serde(rename = "ReadDataTimeMs")]
    pub read_data_time_ms: Option<i32>,
    #[serde(rename = "WriteDataTimeMs")]
    pub write_data_time_ms: Option<i32>,
    #[serde(rename = "CommitAndPublishTimeMs")]
    pub commit_and_publish_time_ms: Option<i32>,
    #[serde(rename = "StreamLoadPlanTimeMs")]
    pub stream_load_plan_time_ms: Option<i32>,
    #[serde(rename = "ExistingJobStatus")]
    pub existing_job_status: Option<String>,
    // Included in the error returned by `StarrocksClient::finish` when a load fails.
    #[serde(rename = "ErrorURL")]
    pub error_url: Option<String>,
}
786
787pub struct StarrocksClient {
788    insert: InserterInner,
789}
790impl StarrocksClient {
791    pub fn new(insert: InserterInner) -> Self {
792        Self { insert }
793    }
794
795    pub async fn write(&mut self, data: Bytes) -> Result<()> {
796        self.insert.write(data).await?;
797        Ok(())
798    }
799
800    pub async fn finish(self) -> Result<StarrocksInsertResultResponse> {
801        let raw = self.insert.finish().await?;
802        let res: StarrocksInsertResultResponse = serde_json::from_slice(&raw)
803            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
804
805        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
806            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
807                "Insert error: {}, {}, {:?}",
808                res.status,
809                res.message,
810                res.error_url,
811            )));
812        };
813        Ok(res)
814    }
815}
816
817pub struct StarrocksTxnClient {
818    request_builder: StarrocksTxnRequestBuilder,
819}
820
821impl StarrocksTxnClient {
822    pub fn new(request_builder: StarrocksTxnRequestBuilder) -> Self {
823        Self { request_builder }
824    }
825
826    fn check_response_and_extract_label(&self, res: Bytes) -> Result<String> {
827        let res: StarrocksInsertResultResponse = serde_json::from_slice(&res)
828            .map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;
829        if !STARROCKS_SUCCESS_STATUS.contains(&res.status.as_str()) {
830            return Err(SinkError::DorisStarrocksConnect(anyhow::anyhow!(
831                "transaction error: {}, {}, {:?}",
832                res.status,
833                res.message,
834                res.error_url,
835            )));
836        }
837        res.label.ok_or_else(|| {
838            SinkError::DorisStarrocksConnect(anyhow::anyhow!("Can't get label from response"))
839        })
840    }
841
842    pub async fn begin(&self, label: String) -> Result<String> {
843        let res = self
844            .request_builder
845            .build_begin_request_sender(label)?
846            .send()
847            .await?;
848        self.check_response_and_extract_label(res)
849    }
850
851    pub async fn prepare(&self, label: String) -> Result<String> {
852        let res = self
853            .request_builder
854            .build_prepare_request_sender(label)?
855            .send()
856            .await?;
857        self.check_response_and_extract_label(res)
858    }
859
860    pub async fn commit(&self, label: String) -> Result<String> {
861        let res = self
862            .request_builder
863            .build_commit_request_sender(label)?
864            .send()
865            .await?;
866        self.check_response_and_extract_label(res)
867    }
868
869    pub async fn rollback(&self, label: String) -> Result<String> {
870        let res = self
871            .request_builder
872            .build_rollback_request_sender(label)?
873            .send()
874            .await?;
875        self.check_response_and_extract_label(res)
876    }
877
878    pub async fn load(&self, label: String) -> Result<InserterInner> {
879        self.request_builder.build_txn_inserter(label).await
880    }
881}
882
#[cfg(test)]
mod tests {
    use risingwave_common::types::DataType;

    use super::*;

    /// Returns whether `StarrocksSink::check_and_correct_column_type` accepts `rw_data_type`
    /// for a StarRocks column declared as `starrocks_data_type`.
    ///
    /// Panics if the check itself returns an error.
    fn is_compatible(rw_data_type: DataType, starrocks_data_type: &str) -> bool {
        StarrocksSink::check_and_correct_column_type(&rw_data_type, starrocks_data_type).unwrap()
    }

    #[test]
    fn test_timestamptz_compatible_starrocks_types() {
        for starrocks_data_type in [
            "datetime",
            "datetime(6)",
            "varchar(64)",
            "char(32)",
            "string",
        ] {
            assert!(
                is_compatible(DataType::Timestamptz, starrocks_data_type),
                "{starrocks_data_type} should be compatible with timestamptz"
            );
        }
    }

    #[test]
    fn test_timestamptz_incompatible_starrocks_types() {
        for starrocks_data_type in ["date", "int", "bigint", "json", "boolean"] {
            assert!(
                !is_compatible(DataType::Timestamptz, starrocks_data_type),
                "{starrocks_data_type} should not be compatible with timestamptz"
            );
        }
    }

    /// Pins the wire names used by the transaction calls.
    ///
    /// Keys that are absent must deserialize as `None` instead of failing.
    #[test]
    fn test_parse_starrocks_txn_response() {
        let raw = br#"{"TxnId":1001,"Label":"rw-label","Status":"OK","Message":"","BeginTxnTimeMs":0}"#;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(raw).unwrap();
        assert_eq!(res.txn_id, Some(1001));
        assert_eq!(res.label.as_deref(), Some("rw-label"));
        assert_eq!(res.status, "OK");
        assert_eq!(res.message, "");
        assert_eq!(res.begin_txn_time_ms, Some(0));
        assert!(res.number_total_rows.is_none());
        assert!(res.error_url.is_none());
    }

    /// `ErrorURL` does not follow the PascalCase pattern of the other keys, so pin it explicitly.
    #[test]
    fn test_parse_starrocks_error_response() {
        let raw = br#"{"Status":"Fail","Message":"too many filtered rows","NumberFilteredRows":3,"ErrorURL":"http://be:8040/api/_load_error_log?file=x"}"#;
        let res: StarrocksInsertResultResponse = serde_json::from_slice(raw).unwrap();
        assert_eq!(res.status, "Fail");
        assert_eq!(res.message, "too many filtered rows");
        assert_eq!(res.number_filtered_rows, Some(3));
        assert_eq!(
            res.error_url.as_deref(),
            Some("http://be:8040/api/_load_error_log?file=x")
        );
        assert!(res.label.is_none());
    }
}