risingwave_connector/sink/sqlserver.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Decimal};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use tiberius::numeric::Numeric;
use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use with_options::WithOptions;

use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
};
use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const SQLSERVER_SINK: &str = "sqlserver";

fn default_max_batch_rows() -> usize {
    1024
}

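/// Connection and batching options for the SQL Server sink, parsed from the
/// `WITH` clause of `CREATE SINK`. All values arrive as strings, which is why
/// the numeric fields use the `DisplayFromStr` adapter.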
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct SqlServerConfig {
    #[serde(rename = "sqlserver.host")]
    pub host: String,
    #[serde(rename = "sqlserver.port")]
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    #[serde(rename = "sqlserver.user")]
    pub user: String,
    #[serde(rename = "sqlserver.password")]
    pub password: String,
    #[serde(rename = "sqlserver.database")]
    pub database: String,
    #[serde(rename = "sqlserver.table")]
    pub table: String,
    #[serde(
        rename = "sqlserver.max_batch_rows",
        default = "default_max_batch_rows"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accept "append-only" or "upsert"
}

impl SqlServerConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct SqlServerSink {
    pub config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl SqlServerSink {
    pub fn new(
        mut config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        // Cap the batch size: tiberius allows at most 2100 parameters in one
        // query request, so stay below that limit with some headroom.
        const TIBERIUS_PARAM_MAX: usize = 2000;
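        // For example, a sink with 25 columns gets capped at
        // floor(2000 / 25) = 80 rows per flushed batch.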
        let params_per_op = schema.fields().len();
        let tiberius_max_batch_rows = if params_per_op == 0 {
            config.max_batch_rows
        } else {
            ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
        };
        if tiberius_max_batch_rows == 0 {
            return Err(SinkError::SqlServer(anyhow!(format!(
                "too many columns: {}",
                params_per_op
            ))));
        }
        config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for SqlServerSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SqlServerConfig::from_btreemap(param.properties)?;
        SqlServerSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for SqlServerSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = LogSinkerOf<SqlServerSinkWriter>;

    const SINK_NAME: &'static str = SQLSERVER_SINK;

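    /// Checks performed before the sink is created: the license feature gate,
    /// primary-key presence for upsert sinks, data type compatibility, and a
    /// column/PK comparison against the downstream SQL Server table.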
    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SqlServerSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;

        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
            )));
        }

        for f in self.schema.fields() {
            check_data_type_compatibility(&f.data_type)?;
        }

        // Query table metadata from SQL Server.
        let mut sql_server_table_metadata = HashMap::new();
        let mut sql_client = SqlServerClient::new(&self.config).await?;
        let query_table_metadata_error = || {
            SinkError::SqlServer(anyhow!(format!(
                "SQL Server table {} metadata error",
                self.config.table
            )))
        };
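        // For every column of the target table, this query returns the column
        // name plus a non-NULL index id iff the column belongs to the primary
        // key (via the joins on sys.index_columns and sys.indexes).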
        static QUERY_TABLE_METADATA: &str = r#"
SELECT
    col.name AS ColumnName,
    pk.index_id AS PkIndex
FROM
    sys.columns col
LEFT JOIN
    sys.index_columns ic ON ic.object_id = col.object_id AND ic.column_id = col.column_id
LEFT JOIN
    sys.indexes pk ON pk.object_id = col.object_id AND pk.index_id = ic.index_id AND pk.is_primary_key = 1
WHERE
    col.object_id = OBJECT_ID(@P1)
ORDER BY
    col.column_id;"#;
        let rows = sql_client
            .inner_client
            .query(QUERY_TABLE_METADATA, &[&self.config.table])
            .await?
            .into_results()
            .await?;
        for row in rows.into_iter().flatten() {
            let mut iter = row.into_iter();
            let ColumnData::String(Some(col_name)) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            let ColumnData::I32(col_pk_index) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            sql_server_table_metadata.insert(col_name.into_owned(), col_pk_index.is_some());
        }

        // Validate column names and the primary key against the downstream table.
        for (idx, col) in self.schema.fields().iter().enumerate() {
            let rw_is_pk = self.pk_indices.contains(&idx);
            match sql_server_table_metadata.get(&col.name) {
                None => {
                    return Err(SinkError::SqlServer(anyhow!(format!(
                        "column {} not found in the downstream SQL Server table {}",
                        col.name, self.config.table
                    ))));
                }
                Some(sql_server_is_pk) => {
                    if self.is_append_only {
                        continue;
                    }
                    if rw_is_pk && !*sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is specified in primary_key but is not a PK column of the downstream SQL Server table {}",
                            col.name, self.config.table,
                        ))));
                    }
                    if !rw_is_pk && *sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is not specified in primary_key but is a PK column of the downstream SQL Server table {}",
                            col.name, self.config.table,
                        ))));
                    }
                }
            }
        }

        if !self.is_append_only {
            let sql_server_pk_count = sql_server_table_metadata
                .values()
                .filter(|is_pk| **is_pk)
                .count();
            if sql_server_pk_count != self.pk_indices.len() {
                return Err(SinkError::SqlServer(anyhow!(format!(
                    "primary key does not match between RisingWave sink ({}) and SQL Server table {} ({})",
                    self.pk_indices.len(),
                    self.config.table,
                    sql_server_pk_count,
                ))));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(SqlServerSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
    }
}

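/// A buffered write held until the next flush: `Insert` renders as a plain
/// `INSERT`, `Merge` as a `MERGE` (upsert) statement, and `Delete` as a
/// `DELETE` keyed on the primary key columns.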
enum SqlOp {
    Insert(OwnedRow),
    Merge(OwnedRow),
    Delete(OwnedRow),
}

pub struct SqlServerSinkWriter {
    config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    sql_client: SqlServerClient,
    ops: Vec<SqlOp>,
}

impl SqlServerSinkWriter {
    async fn new(
        config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let sql_client = SqlServerClient::new(&config).await?;
        let writer = Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            sql_client,
            ops: vec![],
        };
        Ok(writer)
    }

    async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Delete(row.into_owned_row()));
        Ok(())
    }

    async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Merge(row.into_owned_row()));
        Ok(())
    }

    async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Insert(row.into_owned_row()));
        Ok(())
    }

    async fn flush(&mut self) -> Result<()> {
        use std::fmt::Write;
        if self.ops.is_empty() {
            return Ok(());
        }
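        // The flush runs in two passes: first render one SQL statement per
        // buffered op into a single batch string, numbering the placeholders
        // @P1, @P2, ... across the whole batch, then bind the row values in
        // the same order.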
        let mut query_str = String::new();
        let col_num = self.schema.fields.len();
        let mut next_param_id = 1;
        let non_pk_col_indices = (0..col_num)
            .filter(|idx| !self.pk_indices.contains(idx))
            .collect::<Vec<usize>>();
        let all_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let all_source_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[SOURCE].[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let pk_match = self
            .pk_indices
            .iter()
            .map(|idx| {
                format!(
                    "[SOURCE].[{}]=[TARGET].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(" AND ");
        let param_placeholders = |param_id: &mut usize| {
            let params = (*param_id..(*param_id + col_num))
                .map(|i| format!("@P{}", i))
                .collect::<Vec<_>>()
                .join(",");
            *param_id += col_num;
            params
        };
        let set_all_source_col = non_pk_col_indices
            .iter()
            .map(|idx| {
                format!(
                    "[{}]=[SOURCE].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(",");
        // TODO: avoid repeating the SQL
        for op in &self.ops {
            match op {
                SqlOp::Insert(_) => {
                    write!(
                        &mut query_str,
                        "INSERT INTO [{}] ({}) VALUES ({});",
                        self.config.table,
                        all_col_names,
                        param_placeholders(&mut next_param_id),
                    )
                    .unwrap();
                }
                SqlOp::Merge(_) => {
                    write!(
                        &mut query_str,
                        r#"MERGE [{}] AS [TARGET]
                        USING (VALUES ({})) AS [SOURCE] ({})
                        ON {}
                        WHEN MATCHED THEN UPDATE SET {}
                        WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
                        self.config.table,
                        param_placeholders(&mut next_param_id),
                        all_col_names,
                        pk_match,
                        set_all_source_col,
                        all_col_names,
                        all_source_col_names,
                    )
                    .unwrap();
                }
                SqlOp::Delete(_) => {
                    write!(
                        &mut query_str,
                        r#"DELETE FROM [{}] WHERE {};"#,
                        self.config.table,
                        self.pk_indices
                            .iter()
                            .map(|idx| {
                                let condition =
                                    format!("[{}]=@P{}", self.schema[*idx].name, next_param_id);
                                next_param_id += 1;
                                condition
                            })
                            .collect::<Vec<_>>()
                            .join(" AND "),
                    )
                    .unwrap();
                }
            }
        }

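        // Second pass: bind values in exactly the order the placeholders were
        // generated above. Insert/merge bind every column; delete binds only
        // the primary key columns.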
        let mut query = Query::new(query_str);
        for op in self.ops.drain(..) {
            match op {
                SqlOp::Insert(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Merge(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Delete(row) => {
                    bind_params(
                        &mut query,
                        row,
                        &self.schema,
                        self.pk_indices.iter().copied(),
                    )?;
                }
            }
        }
        query.execute(&mut self.sql_client.inner_client).await?;
        Ok(())
    }
}

#[async_trait]
impl SinkWriter for SqlServerSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    if self.is_append_only {
                        self.insert_one(row).await?;
                    } else {
                        self.upsert_one(row).await?;
                    }
                }
                Op::UpdateInsert => {
                    debug_assert!(!self.is_append_only);
                    self.upsert_one(row).await?;
                }
                Op::Delete => {
                    debug_assert!(!self.is_append_only);
                    self.delete_one(row).await?;
                }
                Op::UpdateDelete => {}
            }
        }
        Ok(())
    }

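    // Buffered ops become durable only when a checkpoint barrier triggers a
    // flush; non-checkpoint barriers are a no-op.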
    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
        if is_checkpoint {
            self.flush().await?;
        }
        Ok(())
    }
}

#[derive(Debug)]
pub struct SqlServerClient {
    pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
}

impl SqlServerClient {
    async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
        let mut config = Config::new();
        config.host(&msconfig.host);
        config.port(msconfig.port);
        config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
        config.database(&msconfig.database);
        config.trust_cert();
        Self::new_with_config(config).await
    }

    pub async fn new_with_config(mut config: Config) -> Result<Self> {
        let tcp = TcpStream::connect(config.get_addr())
            .await
            .context("failed to connect to sql server")
            .map_err(SinkError::SqlServer)?;
        tcp.set_nodelay(true)
            .context("failed to set nodelay when connecting to sql server")
            .map_err(SinkError::SqlServer)?;

        let client = match Client::connect(config.clone(), tcp.compat_write()).await {
            // Connection successful.
            Ok(client) => client,
            // The server wants us to redirect to a different address.
            Err(tiberius::error::Error::Routing { host, port }) => {
                config.host(&host);
                config.port(port);
                let tcp = TcpStream::connect(config.get_addr())
                    .await
                    .context("failed to connect to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                tcp.set_nodelay(true)
                    .context("failed to set nodelay when connecting to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                // We should not have more than one redirect, so we'll short-circuit here.
                Client::connect(config, tcp.compat_write()).await?
            }
            Err(e) => return Err(e.into()),
        };

        Ok(Self {
            inner_client: client,
        })
    }
}

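/// Binds the values of `row` at `col_indices` to `query`, in iteration order.
/// A NULL has to be bound as a typed `Option` so that tiberius can still infer
/// the SQL type, hence the per-`DataType` match in the `None` arm.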
fn bind_params(
    query: &mut Query<'_>,
    row: impl Row,
    schema: &Schema,
    col_indices: impl Iterator<Item = usize>,
) -> Result<()> {
    use risingwave_common::types::ScalarRefImpl;
    for col_idx in col_indices {
        match row.datum_at(col_idx) {
            Some(data_ref) => match data_ref {
                ScalarRefImpl::Int16(v) => query.bind(v),
                ScalarRefImpl::Int32(v) => query.bind(v),
                ScalarRefImpl::Int64(v) => query.bind(v),
                ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
                ScalarRefImpl::Bool(v) => query.bind(v),
                ScalarRefImpl::Decimal(v) => match v {
                    Decimal::Normalized(d) => {
                        query.bind(decimal_to_sql(&d));
                    }
                    Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
                        tracing::warn!(
                            "Inf, -Inf and NaN in RisingWave decimal are converted into SQL Server NULL!"
                        );
                        query.bind(None as Option<Numeric>);
                    }
                },
                ScalarRefImpl::Date(v) => query.bind(v.0),
                ScalarRefImpl::Timestamp(v) => query.bind(v.0),
                ScalarRefImpl::Timestamptz(v) => query.bind(v.timestamp_micros()),
                ScalarRefImpl::Time(v) => query.bind(v.0),
                ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
                ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
                ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
                ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
                ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
                ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
                ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
                ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
            },
            None => match schema[col_idx].data_type {
                DataType::Boolean => {
                    query.bind(None as Option<bool>);
                }
                DataType::Int16 => {
                    query.bind(None as Option<i16>);
                }
                DataType::Int32 => {
                    query.bind(None as Option<i32>);
                }
                DataType::Int64 => {
                    query.bind(None as Option<i64>);
                }
                DataType::Float32 => {
                    query.bind(None as Option<f32>);
                }
                DataType::Float64 => {
                    query.bind(None as Option<f64>);
                }
                DataType::Decimal => {
                    query.bind(None as Option<Numeric>);
                }
                DataType::Date => {
                    query.bind(None as Option<chrono::NaiveDate>);
                }
                DataType::Time => {
                    query.bind(None as Option<chrono::NaiveTime>);
                }
                DataType::Timestamp => {
                    query.bind(None as Option<chrono::NaiveDateTime>);
                }
                DataType::Timestamptz => {
                    query.bind(None as Option<i64>);
                }
                DataType::Varchar => {
                    query.bind(None as Option<String>);
                }
                DataType::Bytea => {
                    query.bind(None as Option<Vec<u8>>);
                }
                DataType::Interval => return Err(data_type_not_supported("Interval")),
                DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
                DataType::List(_) => return Err(data_type_not_supported("List")),
                DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
                DataType::Serial => return Err(data_type_not_supported("Serial")),
                DataType::Int256 => return Err(data_type_not_supported("Int256")),
                DataType::Map(_) => return Err(data_type_not_supported("Map")),
            },
        };
    }
    Ok(())
}

fn data_type_not_supported(data_type_name: &str) -> SinkError {
    SinkError::SqlServer(anyhow!(format!(
        "{data_type_name} is not supported in SQL Server"
    )))
}

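/// Rejects RisingWave data types that `bind_params` has no SQL Server binding
/// for; the two matches should be kept in sync.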
fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
    match data_type {
        DataType::Boolean
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal
        | DataType::Date
        | DataType::Varchar
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Bytea => Ok(()),
        DataType::Interval => Err(data_type_not_supported("Interval")),
        DataType::Struct(_) => Err(data_type_not_supported("Struct")),
        DataType::List(_) => Err(data_type_not_supported("List")),
        DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
        DataType::Serial => Err(data_type_not_supported("Serial")),
        DataType::Int256 => Err(data_type_not_supported("Int256")),
        DataType::Map(_) => Err(data_type_not_supported("Map")),
    }
}

/// The implementation is copied from the tiberius crate.
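/// `rust_decimal::Decimal` stores a 96-bit mantissa in `lo`/`mid`/`hi` words
/// plus a sign and a scale; e.g. `123.45` unpacks to mantissa `12345` with
/// scale `2`. The mantissa is reassembled into an `i128`, negated for negative
/// values, and wrapped into a tiberius `Numeric` with the same scale.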
fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
    let unpacked = decimal.unpack();

    let mut value = (((unpacked.hi as u128) << 64)
        + ((unpacked.mid as u128) << 32)
        + unpacked.lo as u128) as i128;

    if decimal.is_sign_negative() {
        value = -value;
    }

    Numeric::new_with_scale(value, decimal.scale() as u8)
}