risingwave_connector/sink/
sqlserver.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16
17use anyhow::{Context, anyhow};
18use async_trait::async_trait;
19use phf::{Set, phf_set};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::{OwnedRow, Row};
23use risingwave_common::types::{DataType, Decimal};
24use serde::Deserialize;
25use serde_with::{DisplayFromStr, serde_as};
26use simd_json::prelude::ArrayTrait;
27use tiberius::numeric::Numeric;
28use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
29use tokio::net::TcpStream;
30use tokio_util::compat::TokioAsyncWriteCompatExt;
31use with_options::WithOptions;
32
33use super::{
34    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
35};
36use crate::enforce_secret::EnforceSecret;
37use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
38use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};
39
40pub const SQLSERVER_SINK: &str = "sqlserver";
41
/// Serde default for the `sqlserver.max_batch_rows` option.
fn default_max_batch_rows() -> usize {
    const DEFAULT_MAX_BATCH_ROWS: usize = 1024;
    DEFAULT_MAX_BATCH_ROWS
}
45
// Options of the SQL Server sink. Every field is deserialized from a string property
// (see `SqlServerConfig::from_btreemap`); the property key is given by the `rename`.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct SqlServerConfig {
    // Host passed to the tiberius connection config.
    #[serde(rename = "sqlserver.host")]
    pub host: String,
    // Port, parsed from its string form via `DisplayFromStr`.
    #[serde(rename = "sqlserver.port")]
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    // User and password used for SQL Server authentication.
    #[serde(rename = "sqlserver.user")]
    pub user: String,
    #[serde(rename = "sqlserver.password")]
    pub password: String,
    // Database, schema and table together form the target of `full_object_path`.
    #[serde(rename = "sqlserver.database")]
    pub database: String,
    // Falls back to `sql_server_default_schema()` when the property is absent.
    #[serde(rename = "sqlserver.schema", default = "sql_server_default_schema")]
    pub schema: String,
    #[serde(rename = "sqlserver.table")]
    pub table: String,
    // Falls back to `default_max_batch_rows()` when absent; `SqlServerSink::new` may
    // lower it further to respect the per-request parameter budget.
    #[serde(
        rename = "sqlserver.max_batch_rows",
        default = "default_max_batch_rows"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String, // accept "append-only" or "upsert"
}
72
/// Schema name used when the `sqlserver.schema` option is not provided.
pub fn sql_server_default_schema() -> String {
    String::from("dbo")
}
76
77impl SqlServerConfig {
78    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
79        let config =
80            serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
81                .map_err(|e| SinkError::Config(anyhow!(e)))?;
82        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
83            return Err(SinkError::Config(anyhow!(
84                "`{}` must be {}, or {}",
85                SINK_TYPE_OPTION,
86                SINK_TYPE_APPEND_ONLY,
87                SINK_TYPE_UPSERT
88            )));
89        }
90        Ok(config)
91    }
92
93    pub fn full_object_path(&self) -> String {
94        format!("[{}].[{}].[{}]", self.database, self.schema, self.table)
95    }
96}
97
// Lists `sqlserver.password` in `ENFORCE_SECRET_PROPERTIES`.
// NOTE(review): presumably `EnforceSecret::enforce_one` rejects these properties unless
// they are provided as secrets — confirm against the trait definition.
impl EnforceSecret for SqlServerConfig {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "sqlserver.password"
    };
}
#[derive(Debug)]
/// SQL Server sink definition: the connector config together with the sink schema,
/// primary-key column indices and append-only flag it was created with.
pub struct SqlServerSink {
    /// Connector options; `max_batch_rows` may have been lowered by [`SqlServerSink::new`].
    pub config: SqlServerConfig,
    /// Columns written by the sink.
    schema: Schema,
    /// Indices into `schema` of the primary-key columns.
    pk_indices: Vec<usize>,
    /// `true` for an append-only sink; otherwise the sink is treated as upsert.
    is_append_only: bool,
}
110
/// Metadata of one column of the downstream SQL Server table, as produced by
/// `query_sql_server_table_metadata`.
struct SqlServerColumnMetadata {
    /// Column name after `normalize_sql_server_column_name`.
    name: String,
    /// Whether the column belongs to the table's primary-key index.
    is_pk: bool,
    /// Type name reported by `sys.types` for the column.
    data_type: String,
}
116
117impl EnforceSecret for SqlServerSink {
118    fn enforce_secret<'a>(
119        prop_iter: impl Iterator<Item = &'a str>,
120    ) -> crate::sink::ConnectorResult<()> {
121        for prop in prop_iter {
122            SqlServerConfig::enforce_one(prop)?;
123        }
124        Ok(())
125    }
126}
127impl SqlServerSink {
128    pub fn new(
129        mut config: SqlServerConfig,
130        schema: Schema,
131        pk_indices: Vec<usize>,
132        is_append_only: bool,
133    ) -> Result<Self> {
134        // Rewrite config because tiberius allows a maximum of 2100 params in one query request.
135        const TIBERIUS_PARAM_MAX: usize = 2000;
136        let params_per_op = schema.fields().len();
137        let tiberius_max_batch_rows = if params_per_op == 0 {
138            config.max_batch_rows
139        } else {
140            ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
141        };
142        if tiberius_max_batch_rows == 0 {
143            return Err(SinkError::SqlServer(anyhow!(format!(
144                "too many column {}",
145                params_per_op
146            ))));
147        }
148        config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
149        Ok(Self {
150            config,
151            schema,
152            pk_indices,
153            is_append_only,
154        })
155    }
156}
157
158impl TryFrom<SinkParam> for SqlServerSink {
159    type Error = SinkError;
160
161    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
162        let schema = param.schema();
163        let pk_indices = param.downstream_pk_or_empty();
164        let config = SqlServerConfig::from_btreemap(param.properties)?;
165        SqlServerSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
166    }
167}
168
169impl Sink for SqlServerSink {
170    type LogSinker = LogSinkerOf<SqlServerSinkWriter>;
171
172    const SINK_NAME: &'static str = SQLSERVER_SINK;
173
174    async fn validate(&self) -> Result<()> {
175        risingwave_common::license::Feature::SqlServerSink
176            .check_available()
177            .map_err(|e| anyhow::anyhow!(e))?;
178
179        if !self.is_append_only && self.pk_indices.is_empty() {
180            return Err(SinkError::Config(anyhow!(
181                "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
182            )));
183        }
184
185        for f in self.schema.fields() {
186            check_data_type_compatibility(&f.data_type)?;
187        }
188
189        let mut sql_client = SqlServerClient::new(&self.config).await?;
190        validate_sql_server_write_permission(&mut sql_client, &self.config, self.is_append_only)
191            .await?;
192        let sql_server_table_metadata =
193            query_sql_server_table_metadata(&mut sql_client, &self.config).await?;
194        let sql_server_pk_count = sql_server_table_metadata
195            .iter()
196            .filter(|metadata| metadata.is_pk)
197            .count();
198        let sql_server_table_metadata = sql_server_table_metadata
199            .into_iter()
200            .map(|metadata| (metadata.name.clone(), metadata))
201            .collect::<HashMap<_, _>>();
202
203        // Validate Column name, Primary Key and data type.
204        for (idx, col) in self.schema.fields().iter().enumerate() {
205            let rw_is_pk = self.pk_indices.contains(&idx);
206            match sql_server_table_metadata.get(&normalize_sql_server_column_name(&col.name)) {
207                None => {
208                    return Err(SinkError::SqlServer(anyhow!(format!(
209                        "column {} not found in the downstream SQL Server table {}",
210                        col.name,
211                        self.config.full_object_path()
212                    ))));
213                }
214                Some(sql_server_col) => {
215                    validate_data_type_compatibility(
216                        &col.name,
217                        &col.data_type,
218                        &sql_server_col.data_type,
219                    )?;
220                    if self.is_append_only {
221                        continue;
222                    }
223                    if rw_is_pk && !sql_server_col.is_pk {
224                        return Err(SinkError::SqlServer(anyhow!(format!(
225                            "column {} specified in primary_key mismatches with the downstream SQL Server table {} PK",
226                            col.name,
227                            self.config.full_object_path(),
228                        ))));
229                    }
230                    if !rw_is_pk && sql_server_col.is_pk {
231                        return Err(SinkError::SqlServer(anyhow!(format!(
232                            "column {} unspecified in primary_key mismatches with the downstream SQL Server table {} PK",
233                            col.name,
234                            self.config.full_object_path(),
235                        ))));
236                    }
237                }
238            }
239        }
240
241        if !self.is_append_only && sql_server_pk_count != self.pk_indices.len() {
242            let sql_server_pk_columns = sql_server_table_metadata
243                .values()
244                .filter(|metadata| metadata.is_pk)
245                .map(|metadata| metadata.name.as_str())
246                .collect::<Vec<_>>()
247                .join(",");
248            let rw_pk_columns = self
249                .pk_indices
250                .iter()
251                .map(|idx| self.schema[*idx].name.as_str())
252                .collect::<Vec<_>>()
253                .join(",");
254            return Err(SinkError::SqlServer(anyhow!(format!(
255                "primary key does not match between RisingWave sink ({}: [{}]) and SQL Server table {} ({}: [{}])",
256                self.pk_indices.len(),
257                rw_pk_columns,
258                self.config.full_object_path(),
259                sql_server_pk_count,
260                sql_server_pk_columns,
261            ))));
262        }
263
264        Ok(())
265    }
266
267    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
268        Ok(SqlServerSinkWriter::new(
269            self.config.clone(),
270            self.schema.clone(),
271            self.pk_indices.clone(),
272            self.is_append_only,
273        )
274        .await?
275        .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
276    }
277}
278
/// A buffered row operation; each variant owns the row it applies to.
enum SqlOp {
    /// Written as a plain `INSERT` statement.
    Insert(OwnedRow),
    /// Written as a `MERGE` statement matched on the primary-key columns.
    Merge(OwnedRow),
    /// Written as a `DELETE` statement filtered on the primary-key columns.
    Delete(OwnedRow),
}
284
/// Sink writer that buffers row operations in `ops` and sends them to SQL Server as one
/// parameterized request per flush.
pub struct SqlServerSinkWriter {
    config: SqlServerConfig,
    schema: Schema,
    /// Indices into `schema` of the primary-key columns.
    pk_indices: Vec<usize>,
    is_append_only: bool,
    /// Downstream SQL Server type name of each column, in `schema` order.
    downstream_column_data_types: Vec<String>,
    sql_client: SqlServerClient,
    /// Operations buffered since the last flush.
    ops: Vec<SqlOp>,
}
294
295impl SqlServerSinkWriter {
296    async fn new(
297        config: SqlServerConfig,
298        schema: Schema,
299        pk_indices: Vec<usize>,
300        is_append_only: bool,
301    ) -> Result<Self> {
302        let mut sql_client = SqlServerClient::new(&config).await?;
303        let downstream_column_data_types =
304            query_downstream_column_metadata(&mut sql_client, &config, &schema)
305                .await?
306                .into_iter()
307                .map(|metadata| metadata.data_type)
308                .collect();
309        let writer = Self {
310            config,
311            schema,
312            pk_indices,
313            is_append_only,
314            downstream_column_data_types,
315            sql_client,
316            ops: vec![],
317        };
318        Ok(writer)
319    }
320
321    async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
322        if self.ops.len() + 1 >= self.config.max_batch_rows {
323            self.flush().await?;
324        }
325        self.ops.push(SqlOp::Delete(row.into_owned_row()));
326        Ok(())
327    }
328
329    async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
330        if self.ops.len() + 1 >= self.config.max_batch_rows {
331            self.flush().await?;
332        }
333        self.ops.push(SqlOp::Merge(row.into_owned_row()));
334        Ok(())
335    }
336
337    async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
338        if self.ops.len() + 1 >= self.config.max_batch_rows {
339            self.flush().await?;
340        }
341        self.ops.push(SqlOp::Insert(row.into_owned_row()));
342        Ok(())
343    }
344
345    async fn flush(&mut self) -> Result<()> {
346        use std::fmt::Write;
347        if self.ops.is_empty() {
348            return Ok(());
349        }
350        let mut query_str = String::new();
351        let col_num = self.schema.fields.len();
352        let mut next_param_id = 1;
353        let non_pk_col_indices = (0..col_num)
354            .filter(|idx| !self.pk_indices.contains(idx))
355            .collect::<Vec<usize>>();
356        let all_col_names = self
357            .schema
358            .fields
359            .iter()
360            .map(|f| format!("[{}]", f.name))
361            .collect::<Vec<_>>()
362            .join(",");
363        let all_source_col_names = self
364            .schema
365            .fields
366            .iter()
367            .map(|f| format!("[SOURCE].[{}]", f.name))
368            .collect::<Vec<_>>()
369            .join(",");
370        let pk_match = self
371            .pk_indices
372            .iter()
373            .map(|idx| {
374                format!(
375                    "[SOURCE].[{}]=[TARGET].[{}]",
376                    &self.schema[*idx].name, &self.schema[*idx].name
377                )
378            })
379            .collect::<Vec<_>>()
380            .join(" AND ");
381        let param_placeholders = |param_id: &mut usize| {
382            (0..col_num)
383                .map(|_| param_placeholder(param_id))
384                .collect::<Vec<_>>()
385                .join(",")
386        };
387        let set_all_source_col = non_pk_col_indices
388            .iter()
389            .map(|idx| {
390                format!(
391                    "[{}]=[SOURCE].[{}]",
392                    &self.schema[*idx].name, &self.schema[*idx].name
393                )
394            })
395            .collect::<Vec<_>>()
396            .join(",");
397        // TODO: avoid repeating the SQL
398        for op in &self.ops {
399            match op {
400                SqlOp::Insert(_) => {
401                    write!(
402                        &mut query_str,
403                        "INSERT INTO {} ({}) VALUES ({});",
404                        self.config.full_object_path(),
405                        all_col_names,
406                        param_placeholders(&mut next_param_id),
407                    )
408                    .unwrap();
409                }
410                SqlOp::Merge(_) => {
411                    write!(
412                        &mut query_str,
413                        r#"MERGE {} WITH (HOLDLOCK) AS [TARGET]
414                        USING (VALUES ({})) AS [SOURCE] ({})
415                        ON {}
416                        WHEN MATCHED THEN UPDATE SET {}
417                        WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
418                        self.config.full_object_path(),
419                        param_placeholders(&mut next_param_id),
420                        all_col_names,
421                        pk_match,
422                        set_all_source_col,
423                        all_col_names,
424                        all_source_col_names,
425                    )
426                    .unwrap();
427                }
428                SqlOp::Delete(_) => {
429                    write!(
430                        &mut query_str,
431                        r#"DELETE FROM {} WHERE {};"#,
432                        self.config.full_object_path(),
433                        self.pk_indices
434                            .iter()
435                            .map(|idx| {
436                                let condition = format!(
437                                    "[{}]={}",
438                                    self.schema[*idx].name,
439                                    param_placeholder(&mut next_param_id)
440                                );
441                                condition
442                            })
443                            .collect::<Vec<_>>()
444                            .join(" AND "),
445                    )
446                    .unwrap();
447                }
448            }
449        }
450
451        let mut query = Query::new(query_str);
452        for op in self.ops.drain(..) {
453            match op {
454                SqlOp::Insert(row) => {
455                    bind_params(
456                        &mut query,
457                        row,
458                        &self.schema,
459                        &self.downstream_column_data_types,
460                        0..col_num,
461                    )?;
462                }
463                SqlOp::Merge(row) => {
464                    bind_params(
465                        &mut query,
466                        row,
467                        &self.schema,
468                        &self.downstream_column_data_types,
469                        0..col_num,
470                    )?;
471                }
472                SqlOp::Delete(row) => {
473                    bind_params(
474                        &mut query,
475                        row,
476                        &self.schema,
477                        &self.downstream_column_data_types,
478                        self.pk_indices.iter().copied(),
479                    )?;
480                }
481            }
482        }
483        query.execute(&mut self.sql_client.inner_client).await?;
484        Ok(())
485    }
486}
487
488#[async_trait]
489impl SinkWriter for SqlServerSinkWriter {
490    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
491        Ok(())
492    }
493
494    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
495        for (op, row) in chunk.rows() {
496            match op {
497                Op::Insert => {
498                    if self.is_append_only {
499                        self.insert_one(row).await?;
500                    } else {
501                        self.upsert_one(row).await?;
502                    }
503                }
504                Op::UpdateInsert => {
505                    debug_assert!(!self.is_append_only);
506                    self.upsert_one(row).await?;
507                }
508                Op::Delete => {
509                    debug_assert!(!self.is_append_only);
510                    self.delete_one(row).await?;
511                }
512                Op::UpdateDelete => {}
513            }
514        }
515        Ok(())
516    }
517
518    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
519        if is_checkpoint {
520            self.flush().await?;
521        }
522        Ok(())
523    }
524}
525
#[derive(Debug)]
/// Thin wrapper around a tiberius [`Client`] connected over a tokio [`TcpStream`].
pub struct SqlServerClient {
    /// The underlying tiberius client that queries are executed on.
    pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
}
530
531impl SqlServerClient {
532    async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
533        let mut config = Config::new();
534        config.host(&msconfig.host);
535        config.port(msconfig.port);
536        config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
537        config.database(&msconfig.database);
538        config.trust_cert();
539        Self::new_with_config(config).await
540    }
541
542    pub async fn new_with_config(mut config: Config) -> Result<Self> {
543        let tcp = TcpStream::connect(config.get_addr())
544            .await
545            .context("failed to connect to sql server")
546            .map_err(SinkError::SqlServer)?;
547        tcp.set_nodelay(true)
548            .context("failed to setting nodelay when connecting to sql server")
549            .map_err(SinkError::SqlServer)?;
550
551        let client = match Client::connect(config.clone(), tcp.compat_write()).await {
552            // Connection successful.
553            Ok(client) => client,
554            // The server wants us to redirect to a different address
555            Err(tiberius::error::Error::Routing { host, port }) => {
556                config.host(&host);
557                config.port(port);
558                let tcp = TcpStream::connect(config.get_addr())
559                    .await
560                    .context("failed to connect to sql server after routing")
561                    .map_err(SinkError::SqlServer)?;
562                tcp.set_nodelay(true)
563                    .context(
564                        "failed to setting nodelay when connecting to sql server after routing",
565                    )
566                    .map_err(SinkError::SqlServer)?;
567                // we should not have more than one redirect, so we'll short-circuit here.
568                Client::connect(config, tcp.compat_write()).await?
569            }
570            Err(e) => return Err(e.into()),
571        };
572
573        Ok(Self {
574            inner_client: client,
575        })
576    }
577}
578
579async fn query_sql_server_table_metadata(
580    sql_client: &mut SqlServerClient,
581    config: &SqlServerConfig,
582) -> Result<Vec<SqlServerColumnMetadata>> {
583    let mut sql_server_table_metadata = Vec::new();
584    let query_table_metadata_error = || {
585        SinkError::SqlServer(anyhow!(format!(
586            "SQL Server table {} metadata error",
587            config.full_object_path()
588        )))
589    };
590    // Query primary-key membership through a subquery filtered by `pk.is_primary_key = 1`.
591    // A column can appear in both the primary-key index and secondary indexes, and a naive
592    // join from `sys.columns` to all `sys.index_columns` would emit extra index rows or mark
593    // secondary-index-only columns as PK columns. Keep the PK filter inside the subquery so
594    // each table column is returned once with `IsPk` set only by the primary-key index.
595    static QUERY_TABLE_METADATA: &str = r#"
596SELECT
597    col.name AS ColumnName,
598    CAST(CASE WHEN pk_col.column_id IS NULL THEN 0 ELSE 1 END AS int) AS IsPk,
599    typ.name AS DataType
600FROM
601    sys.columns col
602JOIN
603    sys.types typ ON typ.user_type_id = col.user_type_id
604LEFT JOIN
605    (
606        SELECT ic.object_id, ic.column_id
607        FROM sys.indexes pk
608        JOIN sys.index_columns ic ON ic.object_id = pk.object_id AND ic.index_id = pk.index_id
609        WHERE pk.is_primary_key = 1
610    ) pk_col ON pk_col.object_id = col.object_id AND pk_col.column_id = col.column_id
611WHERE
612    col.object_id = OBJECT_ID(@P1)
613ORDER BY
614    col.column_id;"#;
615    let rows = sql_client
616        .inner_client
617        .query(QUERY_TABLE_METADATA, &[&config.full_object_path()])
618        .await?
619        .into_results()
620        .await?;
621    for row in rows.into_iter().flatten() {
622        let mut iter = row.into_iter();
623        let ColumnData::String(Some(col_name)) =
624            iter.next().ok_or_else(query_table_metadata_error)?
625        else {
626            return Err(query_table_metadata_error());
627        };
628        let ColumnData::I32(Some(col_is_pk)) =
629            iter.next().ok_or_else(query_table_metadata_error)?
630        else {
631            return Err(query_table_metadata_error());
632        };
633        let ColumnData::String(Some(data_type)) =
634            iter.next().ok_or_else(query_table_metadata_error)?
635        else {
636            return Err(query_table_metadata_error());
637        };
638        sql_server_table_metadata.push(SqlServerColumnMetadata {
639            name: normalize_sql_server_column_name(&col_name),
640            is_pk: col_is_pk != 0,
641            data_type: data_type.into_owned(),
642        });
643    }
644    Ok(sql_server_table_metadata)
645}
646
647async fn validate_sql_server_write_permission(
648    sql_client: &mut SqlServerClient,
649    config: &SqlServerConfig,
650    is_append_only: bool,
651) -> Result<()> {
652    let permission_query_error = || {
653        SinkError::SqlServer(anyhow!(format!(
654            "SQL Server table {} permission metadata error",
655            config.full_object_path()
656        )))
657    };
658    static QUERY_WRITE_PERMISSION: &str = r#"
659SELECT
660    CAST(HAS_PERMS_BY_NAME(@P1, 'OBJECT', 'INSERT') AS int) AS CanInsert,
661    CAST(HAS_PERMS_BY_NAME(@P1, 'OBJECT', 'UPDATE') AS int) AS CanUpdate,
662    CAST(HAS_PERMS_BY_NAME(@P1, 'OBJECT', 'DELETE') AS int) AS CanDelete;"#;
663    let rows = sql_client
664        .inner_client
665        .query(QUERY_WRITE_PERMISSION, &[&config.full_object_path()])
666        .await?
667        .into_results()
668        .await?;
669    let mut rows = rows.into_iter().flatten();
670    let row = rows.next().ok_or_else(permission_query_error)?;
671    let mut iter = row.into_iter();
672    let ColumnData::I32(can_insert) = iter.next().ok_or_else(permission_query_error)? else {
673        return Err(permission_query_error());
674    };
675    let ColumnData::I32(can_update) = iter.next().ok_or_else(permission_query_error)? else {
676        return Err(permission_query_error());
677    };
678    let ColumnData::I32(can_delete) = iter.next().ok_or_else(permission_query_error)? else {
679        return Err(permission_query_error());
680    };
681
682    let missing_permissions = missing_sql_server_write_permissions(
683        is_append_only,
684        permission_is_granted(can_insert),
685        permission_is_granted(can_update),
686        permission_is_granted(can_delete),
687    );
688    if missing_permissions.is_empty() {
689        return Ok(());
690    }
691
692    Err(SinkError::SqlServer(anyhow!(format!(
693        "SQL Server user {} lacks required write permission(s) {} on table {}",
694        config.user,
695        missing_permissions.join(", "),
696        config.full_object_path()
697    ))))
698}
699
/// Returns `true` only when the permission flag is exactly `1`; `NULL` (`None`) and any
/// other value count as "not granted".
fn permission_is_granted(permission_value: Option<i32>) -> bool {
    matches!(permission_value, Some(1))
}
703
/// Lists the write permissions that are required but not granted, in the fixed order
/// `INSERT`, `UPDATE`, `DELETE`.
///
/// `INSERT` is always required; `UPDATE` and `DELETE` only when the sink is not
/// append-only.
fn missing_sql_server_write_permissions(
    is_append_only: bool,
    can_insert: bool,
    can_update: bool,
    can_delete: bool,
) -> Vec<&'static str> {
    let needs_modify = !is_append_only;
    // (required, granted, permission name)
    let checks: [(bool, bool, &'static str); 3] = [
        (true, can_insert, "INSERT"),
        (needs_modify, can_update, "UPDATE"),
        (needs_modify, can_delete, "DELETE"),
    ];
    checks
        .iter()
        .filter(|check| check.0 && !check.1)
        .map(|check| check.2)
        .collect()
}
724
725async fn query_downstream_column_metadata(
726    sql_client: &mut SqlServerClient,
727    config: &SqlServerConfig,
728    schema: &Schema,
729) -> Result<Vec<SqlServerColumnMetadata>> {
730    let sql_server_table_metadata = query_sql_server_table_metadata(sql_client, config)
731        .await?
732        .into_iter()
733        .map(|metadata| (metadata.name.clone(), metadata))
734        .collect::<HashMap<_, _>>();
735    schema
736        .fields()
737        .iter()
738        .map(|col| {
739            sql_server_table_metadata
740                .get(&normalize_sql_server_column_name(&col.name))
741                .map(|metadata| SqlServerColumnMetadata {
742                    name: metadata.name.clone(),
743                    is_pk: metadata.is_pk,
744                    data_type: metadata.data_type.clone(),
745                })
746                .ok_or_else(|| {
747                    SinkError::SqlServer(anyhow!(format!(
748                        "column {} not found in the downstream SQL Server table {}",
749                        col.name,
750                        config.full_object_path()
751                    )))
752                })
753        })
754        .collect()
755}
756
/// Returns the placeholder `@P<n>` for the current value of `param_id` and advances
/// `param_id` by one, so consecutive calls number parameters sequentially.
fn param_placeholder(param_id: &mut usize) -> String {
    let current = *param_id;
    *param_id += 1;
    format!("@P{}", current)
}
762
763fn bind_params(
764    query: &mut Query<'_>,
765    row: impl Row,
766    schema: &Schema,
767    downstream_column_data_types: &[String],
768    col_indices: impl Iterator<Item = usize>,
769) -> Result<()> {
770    use risingwave_common::types::ScalarRefImpl;
771    for col_idx in col_indices {
772        match row.datum_at(col_idx) {
773            Some(data_ref) => match data_ref {
774                ScalarRefImpl::Int16(v) => query.bind(v),
775                ScalarRefImpl::Int32(v) => query.bind(v),
776                ScalarRefImpl::Int64(v) => query.bind(v),
777                ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
778                ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
779                ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
780                ScalarRefImpl::Bool(v) => query.bind(v),
781                ScalarRefImpl::Decimal(v) => match v {
782                    Decimal::Normalized(d) => {
783                        query.bind(decimal_to_sql(&d));
784                    }
785                    Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
786                        tracing::warn!(
787                            "Inf, -Inf, Nan in RisingWave decimal is converted into SQL Server null!"
788                        );
789                        query.bind(None as Option<Numeric>);
790                    }
791                },
792                ScalarRefImpl::Date(v) => query.bind(v.0),
793                ScalarRefImpl::Timestamp(v) => query.bind(v.0),
794                ScalarRefImpl::Timestamptz(v) => {
795                    let downstream_data_type = &downstream_column_data_types[col_idx];
796                    match downstream_data_type.as_str() {
797                        "bigint" | "int" | "smallint" | "tinyint" => {
798                            query.bind(v.timestamp_micros());
799                        }
800                        "datetimeoffset" => {
801                            query.bind(v.to_datetime_utc().fixed_offset());
802                        }
803                        "datetime" | "datetime2" | "smalldatetime" => {
804                            query.bind(v.to_datetime_utc().naive_utc());
805                        }
806                        _ => {
807                            return Err(unexpected_downstream_timestamptz_type(
808                                downstream_data_type,
809                            ));
810                        }
811                    };
812                }
813                ScalarRefImpl::Time(v) => query.bind(v.0),
814                ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
815                ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
816                ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
817                ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
818                ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
819                ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
820                ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
821                ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
822                ScalarRefImpl::Vector(_) => return Err(data_type_not_supported("Vector")),
823            },
824            None => match schema[col_idx].data_type {
825                DataType::Boolean => {
826                    query.bind(None as Option<bool>);
827                }
828                DataType::Int16 => {
829                    query.bind(None as Option<i16>);
830                }
831                DataType::Int32 => {
832                    query.bind(None as Option<i32>);
833                }
834                DataType::Int64 => {
835                    query.bind(None as Option<i64>);
836                }
837                DataType::Float32 => {
838                    query.bind(None as Option<f32>);
839                }
840                DataType::Float64 => {
841                    query.bind(None as Option<f64>);
842                }
843                DataType::Decimal => {
844                    query.bind(None as Option<Numeric>);
845                }
846                DataType::Date => {
847                    query.bind(None as Option<chrono::NaiveDate>);
848                }
849                DataType::Time => {
850                    query.bind(None as Option<chrono::NaiveTime>);
851                }
852                DataType::Timestamp => {
853                    query.bind(None as Option<chrono::NaiveDateTime>);
854                }
855                DataType::Timestamptz => {
856                    let downstream_data_type = &downstream_column_data_types[col_idx];
857                    match downstream_data_type.as_str() {
858                        "bigint" | "int" | "smallint" | "tinyint" => {
859                            query.bind(None as Option<i64>);
860                        }
861                        "datetimeoffset" => {
862                            query.bind(None as Option<chrono::DateTime<chrono::FixedOffset>>);
863                        }
864                        "datetime" | "datetime2" | "smalldatetime" => {
865                            query.bind(None as Option<chrono::NaiveDateTime>);
866                        }
867                        _ => {
868                            return Err(unexpected_downstream_timestamptz_type(
869                                downstream_data_type,
870                            ));
871                        }
872                    };
873                }
874                DataType::Varchar => {
875                    query.bind(None as Option<String>);
876                }
877                DataType::Bytea => {
878                    query.bind(None as Option<Vec<u8>>);
879                }
880                DataType::Interval => return Err(data_type_not_supported("Interval")),
881                DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
882                DataType::List(_) => return Err(data_type_not_supported("List")),
883                DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
884                DataType::Serial => return Err(data_type_not_supported("Serial")),
885                DataType::Int256 => return Err(data_type_not_supported("Int256")),
886                DataType::Map(_) => return Err(data_type_not_supported("Map")),
887                DataType::Vector(_) => return Err(data_type_not_supported("Vector")),
888            },
889        };
890    }
891    Ok(())
892}
893
894fn data_type_not_supported(data_type_name: &str) -> SinkError {
895    SinkError::SqlServer(anyhow!(format!(
896        "{data_type_name} is not supported in SQL Server"
897    )))
898}
899
900fn unexpected_downstream_timestamptz_type(sql_server_data_type: &str) -> SinkError {
901    SinkError::SqlServer(anyhow!(format!(
902        "unexpected downstream SQL Server type {sql_server_data_type} for Timestamptz"
903    )))
904}
905
906fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
907    match data_type {
908        DataType::Boolean
909        | DataType::Int16
910        | DataType::Int32
911        | DataType::Int64
912        | DataType::Float32
913        | DataType::Float64
914        | DataType::Decimal
915        | DataType::Date
916        | DataType::Varchar
917        | DataType::Time
918        | DataType::Timestamp
919        | DataType::Timestamptz
920        | DataType::Bytea => Ok(()),
921        DataType::Interval => Err(data_type_not_supported("Interval")),
922        DataType::Struct(_) => Err(data_type_not_supported("Struct")),
923        DataType::List(_) => Err(data_type_not_supported("List")),
924        DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
925        DataType::Serial => Err(data_type_not_supported("Serial")),
926        DataType::Int256 => Err(data_type_not_supported("Int256")),
927        DataType::Map(_) => Err(data_type_not_supported("Map")),
928        DataType::Vector(_) => Err(data_type_not_supported("Vector")),
929    }
930}
931
/// Produces the lookup key used to match a column name against SQL Server
/// column metadata.
///
/// The key is the lowercase form of `column_name`.
fn normalize_sql_server_column_name(column_name: &str) -> String {
    // Identifier case sensitivity in SQL Server depends on the database
    // collation and is usually case-insensitive, so metadata is matched on a
    // lowercased key to follow that common behavior.
    str::to_lowercase(column_name)
}
937
938fn validate_data_type_compatibility(
939    column_name: &str,
940    rw_data_type: &DataType,
941    sql_server_data_type: &str,
942) -> Result<()> {
943    if sql_server_data_type_is_compatible(rw_data_type, sql_server_data_type) {
944        return Ok(());
945    }
946
947    Err(SinkError::SqlServer(anyhow!(format!(
948        "column {} data type {:?} is incompatible with downstream SQL Server type {}",
949        column_name, rw_data_type, sql_server_data_type
950    ))))
951}
952
953fn sql_server_data_type_is_compatible(rw_data_type: &DataType, sql_server_data_type: &str) -> bool {
954    match rw_data_type {
955        DataType::Boolean => sql_server_data_type == "bit",
956        DataType::Int16 => matches!(sql_server_data_type, "smallint" | "int" | "bigint"),
957        DataType::Int32 => matches!(sql_server_data_type, "int" | "bigint"),
958        DataType::Int64 => sql_server_data_type == "bigint",
959        DataType::Float32 => matches!(sql_server_data_type, "real" | "float"),
960        DataType::Float64 => sql_server_data_type == "float",
961        DataType::Decimal => matches!(sql_server_data_type, "decimal" | "numeric"),
962        DataType::Date => sql_server_data_type == "date",
963        DataType::Varchar => matches!(
964            sql_server_data_type,
965            "char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext"
966        ),
967        DataType::Time => sql_server_data_type == "time",
968        DataType::Timestamp => {
969            matches!(
970                sql_server_data_type,
971                "datetime" | "datetime2" | "smalldatetime"
972            )
973        }
974        DataType::Timestamptz => matches!(
975            sql_server_data_type,
976            "datetimeoffset"
977                | "datetime"
978                | "datetime2"
979                | "smalldatetime"
980                | "bigint"
981                | "int"
982                | "smallint"
983                | "tinyint"
984        ),
985        DataType::Bytea => matches!(sql_server_data_type, "binary" | "varbinary" | "image"),
986        DataType::Interval
987        | DataType::Struct(_)
988        | DataType::List(_)
989        | DataType::Jsonb
990        | DataType::Serial
991        | DataType::Int256
992        | DataType::Map(_)
993        | DataType::Vector(_) => false,
994    }
995}
996
997/// The implementation is copied from tiberius crate.
998fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
999    let unpacked = decimal.unpack();
1000
1001    let mut value = (((unpacked.hi as u128) << 64)
1002        + ((unpacked.mid as u128) << 32)
1003        + unpacked.lo as u128) as i128;
1004
1005    if decimal.is_sign_negative() {
1006        value = -value;
1007    }
1008
1009    Numeric::new_with_scale(value, decimal.scale() as u8)
1010}
1011
#[cfg(test)]
mod tests {
    use super::*;

    /// Column names are folded to lowercase to build the metadata lookup key.
    #[test]
    fn test_normalize_sql_server_column_name() {
        let normalized = normalize_sql_server_column_name("EventDate");
        assert_eq!(normalized, "eventdate");
    }

    /// Spot-checks accepted and rejected (RisingWave type, SQL Server type) pairs.
    #[test]
    fn test_sql_server_data_type_compatibility() {
        let accepted = [
            (DataType::Int16, "smallint"),
            (DataType::Int16, "int"),
            (DataType::Timestamp, "datetime2"),
            (DataType::Timestamptz, "datetimeoffset"),
            (DataType::Timestamptz, "datetime2"),
            (DataType::Timestamptz, "bigint"),
            (DataType::Timestamptz, "int"),
            (DataType::Varchar, "nvarchar"),
        ];
        for (rw_type, sql_server_type) in &accepted {
            assert!(
                sql_server_data_type_is_compatible(rw_type, sql_server_type),
                "{rw_type:?} should be compatible with {sql_server_type}"
            );
        }

        let rejected = [
            (DataType::Int32, "smallint"),
            (DataType::Timestamp, "datetimeoffset"),
            (DataType::Varchar, "uniqueidentifier"),
        ];
        for (rw_type, sql_server_type) in &rejected {
            assert!(
                !sql_server_data_type_is_compatible(rw_type, sql_server_type),
                "{rw_type:?} should not be compatible with {sql_server_type}"
            );
        }
    }

    /// Exercises `missing_sql_server_write_permissions` with the same flag
    /// combinations and expected results as before.
    #[test]
    fn test_missing_sql_server_write_permissions() {
        let missing = missing_sql_server_write_permissions(true, false, false, false);
        assert_eq!(missing, vec!["INSERT"]);

        let missing = missing_sql_server_write_permissions(true, true, false, false);
        assert!(missing.is_empty());

        let missing = missing_sql_server_write_permissions(false, false, false, false);
        assert_eq!(missing, vec!["INSERT", "UPDATE", "DELETE"]);

        let missing = missing_sql_server_write_permissions(false, true, false, true);
        assert_eq!(missing, vec!["UPDATE"]);

        let missing = missing_sql_server_write_permissions(false, true, true, true);
        assert!(missing.is_empty());
    }
}