1use std::collections::{BTreeMap, HashMap};
16
17use anyhow::{Context, anyhow};
18use async_trait::async_trait;
19use phf::{Set, phf_set};
20use risingwave_common::array::{Op, RowRef, StreamChunk};
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::{OwnedRow, Row};
23use risingwave_common::types::{DataType, Decimal};
24use serde::Deserialize;
25use serde_with::{DisplayFromStr, serde_as};
26use simd_json::prelude::ArrayTrait;
27use tiberius::numeric::Numeric;
28use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
29use tokio::net::TcpStream;
30use tokio_util::compat::TokioAsyncWriteCompatExt;
31use with_options::WithOptions;
32
33use super::{
34 SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
35};
36use crate::enforce_secret::EnforceSecret;
37use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
38use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};
39
40pub const SQLSERVER_SINK: &str = "sqlserver";
41
42fn default_max_batch_rows() -> usize {
43 1024
44}
45
46#[serde_as]
47#[derive(Clone, Debug, Deserialize, WithOptions)]
48pub struct SqlServerConfig {
49 #[serde(rename = "sqlserver.host")]
50 pub host: String,
51 #[serde(rename = "sqlserver.port")]
52 #[serde_as(as = "DisplayFromStr")]
53 pub port: u16,
54 #[serde(rename = "sqlserver.user")]
55 pub user: String,
56 #[serde(rename = "sqlserver.password")]
57 pub password: String,
58 #[serde(rename = "sqlserver.database")]
59 pub database: String,
60 #[serde(rename = "sqlserver.schema", default = "sql_server_default_schema")]
61 pub schema: String,
62 #[serde(rename = "sqlserver.table")]
63 pub table: String,
64 #[serde(
65 rename = "sqlserver.max_batch_rows",
66 default = "default_max_batch_rows"
67 )]
68 #[serde_as(as = "DisplayFromStr")]
69 pub max_batch_rows: usize,
70 pub r#type: String, }
72
73pub fn sql_server_default_schema() -> String {
74 "dbo".to_owned()
75}
76
77impl SqlServerConfig {
78 pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
79 let config =
80 serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
81 .map_err(|e| SinkError::Config(anyhow!(e)))?;
82 if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
83 return Err(SinkError::Config(anyhow!(
84 "`{}` must be {}, or {}",
85 SINK_TYPE_OPTION,
86 SINK_TYPE_APPEND_ONLY,
87 SINK_TYPE_UPSERT
88 )));
89 }
90 Ok(config)
91 }
92
93 pub fn full_object_path(&self) -> String {
94 format!("[{}].[{}].[{}]", self.database, self.schema, self.table)
95 }
96}
97
98impl EnforceSecret for SqlServerConfig {
99 const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
100 "sqlserver.password"
101 };
102}
103#[derive(Debug)]
104pub struct SqlServerSink {
105 pub config: SqlServerConfig,
106 schema: Schema,
107 pk_indices: Vec<usize>,
108 is_append_only: bool,
109}
110
111struct SqlServerColumnMetadata {
112 name: String,
113 is_pk: bool,
114 data_type: String,
115}
116
117impl EnforceSecret for SqlServerSink {
118 fn enforce_secret<'a>(
119 prop_iter: impl Iterator<Item = &'a str>,
120 ) -> crate::sink::ConnectorResult<()> {
121 for prop in prop_iter {
122 SqlServerConfig::enforce_one(prop)?;
123 }
124 Ok(())
125 }
126}
127impl SqlServerSink {
128 pub fn new(
129 mut config: SqlServerConfig,
130 schema: Schema,
131 pk_indices: Vec<usize>,
132 is_append_only: bool,
133 ) -> Result<Self> {
134 const TIBERIUS_PARAM_MAX: usize = 2000;
136 let params_per_op = schema.fields().len();
137 let tiberius_max_batch_rows = if params_per_op == 0 {
138 config.max_batch_rows
139 } else {
140 ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
141 };
142 if tiberius_max_batch_rows == 0 {
143 return Err(SinkError::SqlServer(anyhow!(format!(
144 "too many column {}",
145 params_per_op
146 ))));
147 }
148 config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
149 Ok(Self {
150 config,
151 schema,
152 pk_indices,
153 is_append_only,
154 })
155 }
156}
157
158impl TryFrom<SinkParam> for SqlServerSink {
159 type Error = SinkError;
160
161 fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
162 let schema = param.schema();
163 let pk_indices = param.downstream_pk_or_empty();
164 let config = SqlServerConfig::from_btreemap(param.properties)?;
165 SqlServerSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
166 }
167}
168
169impl Sink for SqlServerSink {
170 type LogSinker = LogSinkerOf<SqlServerSinkWriter>;
171
172 const SINK_NAME: &'static str = SQLSERVER_SINK;
173
174 async fn validate(&self) -> Result<()> {
175 risingwave_common::license::Feature::SqlServerSink
176 .check_available()
177 .map_err(|e| anyhow::anyhow!(e))?;
178
179 if !self.is_append_only && self.pk_indices.is_empty() {
180 return Err(SinkError::Config(anyhow!(
181 "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
182 )));
183 }
184
185 for f in self.schema.fields() {
186 check_data_type_compatibility(&f.data_type)?;
187 }
188
189 let mut sql_client = SqlServerClient::new(&self.config).await?;
190 validate_sql_server_write_permission(&mut sql_client, &self.config, self.is_append_only)
191 .await?;
192 let sql_server_table_metadata =
193 query_sql_server_table_metadata(&mut sql_client, &self.config).await?;
194 let sql_server_pk_count = sql_server_table_metadata
195 .iter()
196 .filter(|metadata| metadata.is_pk)
197 .count();
198 let sql_server_table_metadata = sql_server_table_metadata
199 .into_iter()
200 .map(|metadata| (metadata.name.clone(), metadata))
201 .collect::<HashMap<_, _>>();
202
203 for (idx, col) in self.schema.fields().iter().enumerate() {
205 let rw_is_pk = self.pk_indices.contains(&idx);
206 match sql_server_table_metadata.get(&normalize_sql_server_column_name(&col.name)) {
207 None => {
208 return Err(SinkError::SqlServer(anyhow!(format!(
209 "column {} not found in the downstream SQL Server table {}",
210 col.name,
211 self.config.full_object_path()
212 ))));
213 }
214 Some(sql_server_col) => {
215 validate_data_type_compatibility(
216 &col.name,
217 &col.data_type,
218 &sql_server_col.data_type,
219 )?;
220 if self.is_append_only {
221 continue;
222 }
223 if rw_is_pk && !sql_server_col.is_pk {
224 return Err(SinkError::SqlServer(anyhow!(format!(
225 "column {} specified in primary_key mismatches with the downstream SQL Server table {} PK",
226 col.name,
227 self.config.full_object_path(),
228 ))));
229 }
230 if !rw_is_pk && sql_server_col.is_pk {
231 return Err(SinkError::SqlServer(anyhow!(format!(
232 "column {} unspecified in primary_key mismatches with the downstream SQL Server table {} PK",
233 col.name,
234 self.config.full_object_path(),
235 ))));
236 }
237 }
238 }
239 }
240
241 if !self.is_append_only && sql_server_pk_count != self.pk_indices.len() {
242 let sql_server_pk_columns = sql_server_table_metadata
243 .values()
244 .filter(|metadata| metadata.is_pk)
245 .map(|metadata| metadata.name.as_str())
246 .collect::<Vec<_>>()
247 .join(",");
248 let rw_pk_columns = self
249 .pk_indices
250 .iter()
251 .map(|idx| self.schema[*idx].name.as_str())
252 .collect::<Vec<_>>()
253 .join(",");
254 return Err(SinkError::SqlServer(anyhow!(format!(
255 "primary key does not match between RisingWave sink ({}: [{}]) and SQL Server table {} ({}: [{}])",
256 self.pk_indices.len(),
257 rw_pk_columns,
258 self.config.full_object_path(),
259 sql_server_pk_count,
260 sql_server_pk_columns,
261 ))));
262 }
263
264 Ok(())
265 }
266
267 async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
268 Ok(SqlServerSinkWriter::new(
269 self.config.clone(),
270 self.schema.clone(),
271 self.pk_indices.clone(),
272 self.is_append_only,
273 )
274 .await?
275 .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
276 }
277}
278
279enum SqlOp {
280 Insert(OwnedRow),
281 Merge(OwnedRow),
282 Delete(OwnedRow),
283}
284
285pub struct SqlServerSinkWriter {
286 config: SqlServerConfig,
287 schema: Schema,
288 pk_indices: Vec<usize>,
289 is_append_only: bool,
290 downstream_column_data_types: Vec<String>,
291 sql_client: SqlServerClient,
292 ops: Vec<SqlOp>,
293}
294
295impl SqlServerSinkWriter {
296 async fn new(
297 config: SqlServerConfig,
298 schema: Schema,
299 pk_indices: Vec<usize>,
300 is_append_only: bool,
301 ) -> Result<Self> {
302 let mut sql_client = SqlServerClient::new(&config).await?;
303 let downstream_column_data_types =
304 query_downstream_column_metadata(&mut sql_client, &config, &schema)
305 .await?
306 .into_iter()
307 .map(|metadata| metadata.data_type)
308 .collect();
309 let writer = Self {
310 config,
311 schema,
312 pk_indices,
313 is_append_only,
314 downstream_column_data_types,
315 sql_client,
316 ops: vec![],
317 };
318 Ok(writer)
319 }
320
321 async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
322 if self.ops.len() + 1 >= self.config.max_batch_rows {
323 self.flush().await?;
324 }
325 self.ops.push(SqlOp::Delete(row.into_owned_row()));
326 Ok(())
327 }
328
329 async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
330 if self.ops.len() + 1 >= self.config.max_batch_rows {
331 self.flush().await?;
332 }
333 self.ops.push(SqlOp::Merge(row.into_owned_row()));
334 Ok(())
335 }
336
337 async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
338 if self.ops.len() + 1 >= self.config.max_batch_rows {
339 self.flush().await?;
340 }
341 self.ops.push(SqlOp::Insert(row.into_owned_row()));
342 Ok(())
343 }
344
345 async fn flush(&mut self) -> Result<()> {
346 use std::fmt::Write;
347 if self.ops.is_empty() {
348 return Ok(());
349 }
350 let mut query_str = String::new();
351 let col_num = self.schema.fields.len();
352 let mut next_param_id = 1;
353 let non_pk_col_indices = (0..col_num)
354 .filter(|idx| !self.pk_indices.contains(idx))
355 .collect::<Vec<usize>>();
356 let all_col_names = self
357 .schema
358 .fields
359 .iter()
360 .map(|f| format!("[{}]", f.name))
361 .collect::<Vec<_>>()
362 .join(",");
363 let all_source_col_names = self
364 .schema
365 .fields
366 .iter()
367 .map(|f| format!("[SOURCE].[{}]", f.name))
368 .collect::<Vec<_>>()
369 .join(",");
370 let pk_match = self
371 .pk_indices
372 .iter()
373 .map(|idx| {
374 format!(
375 "[SOURCE].[{}]=[TARGET].[{}]",
376 &self.schema[*idx].name, &self.schema[*idx].name
377 )
378 })
379 .collect::<Vec<_>>()
380 .join(" AND ");
381 let param_placeholders = |param_id: &mut usize| {
382 (0..col_num)
383 .map(|_| param_placeholder(param_id))
384 .collect::<Vec<_>>()
385 .join(",")
386 };
387 let set_all_source_col = non_pk_col_indices
388 .iter()
389 .map(|idx| {
390 format!(
391 "[{}]=[SOURCE].[{}]",
392 &self.schema[*idx].name, &self.schema[*idx].name
393 )
394 })
395 .collect::<Vec<_>>()
396 .join(",");
397 for op in &self.ops {
399 match op {
400 SqlOp::Insert(_) => {
401 write!(
402 &mut query_str,
403 "INSERT INTO {} ({}) VALUES ({});",
404 self.config.full_object_path(),
405 all_col_names,
406 param_placeholders(&mut next_param_id),
407 )
408 .unwrap();
409 }
410 SqlOp::Merge(_) => {
411 write!(
412 &mut query_str,
413 r#"MERGE {} WITH (HOLDLOCK) AS [TARGET]
414 USING (VALUES ({})) AS [SOURCE] ({})
415 ON {}
416 WHEN MATCHED THEN UPDATE SET {}
417 WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
418 self.config.full_object_path(),
419 param_placeholders(&mut next_param_id),
420 all_col_names,
421 pk_match,
422 set_all_source_col,
423 all_col_names,
424 all_source_col_names,
425 )
426 .unwrap();
427 }
428 SqlOp::Delete(_) => {
429 write!(
430 &mut query_str,
431 r#"DELETE FROM {} WHERE {};"#,
432 self.config.full_object_path(),
433 self.pk_indices
434 .iter()
435 .map(|idx| {
436 let condition = format!(
437 "[{}]={}",
438 self.schema[*idx].name,
439 param_placeholder(&mut next_param_id)
440 );
441 condition
442 })
443 .collect::<Vec<_>>()
444 .join(" AND "),
445 )
446 .unwrap();
447 }
448 }
449 }
450
451 let mut query = Query::new(query_str);
452 for op in self.ops.drain(..) {
453 match op {
454 SqlOp::Insert(row) => {
455 bind_params(
456 &mut query,
457 row,
458 &self.schema,
459 &self.downstream_column_data_types,
460 0..col_num,
461 )?;
462 }
463 SqlOp::Merge(row) => {
464 bind_params(
465 &mut query,
466 row,
467 &self.schema,
468 &self.downstream_column_data_types,
469 0..col_num,
470 )?;
471 }
472 SqlOp::Delete(row) => {
473 bind_params(
474 &mut query,
475 row,
476 &self.schema,
477 &self.downstream_column_data_types,
478 self.pk_indices.iter().copied(),
479 )?;
480 }
481 }
482 }
483 query.execute(&mut self.sql_client.inner_client).await?;
484 Ok(())
485 }
486}
487
488#[async_trait]
489impl SinkWriter for SqlServerSinkWriter {
490 async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
491 Ok(())
492 }
493
494 async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
495 for (op, row) in chunk.rows() {
496 match op {
497 Op::Insert => {
498 if self.is_append_only {
499 self.insert_one(row).await?;
500 } else {
501 self.upsert_one(row).await?;
502 }
503 }
504 Op::UpdateInsert => {
505 debug_assert!(!self.is_append_only);
506 self.upsert_one(row).await?;
507 }
508 Op::Delete => {
509 debug_assert!(!self.is_append_only);
510 self.delete_one(row).await?;
511 }
512 Op::UpdateDelete => {}
513 }
514 }
515 Ok(())
516 }
517
518 async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
519 if is_checkpoint {
520 self.flush().await?;
521 }
522 Ok(())
523 }
524}
525
526#[derive(Debug)]
527pub struct SqlServerClient {
528 pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
529}
530
531impl SqlServerClient {
532 async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
533 let mut config = Config::new();
534 config.host(&msconfig.host);
535 config.port(msconfig.port);
536 config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
537 config.database(&msconfig.database);
538 config.trust_cert();
539 Self::new_with_config(config).await
540 }
541
542 pub async fn new_with_config(mut config: Config) -> Result<Self> {
543 let tcp = TcpStream::connect(config.get_addr())
544 .await
545 .context("failed to connect to sql server")
546 .map_err(SinkError::SqlServer)?;
547 tcp.set_nodelay(true)
548 .context("failed to setting nodelay when connecting to sql server")
549 .map_err(SinkError::SqlServer)?;
550
551 let client = match Client::connect(config.clone(), tcp.compat_write()).await {
552 Ok(client) => client,
554 Err(tiberius::error::Error::Routing { host, port }) => {
556 config.host(&host);
557 config.port(port);
558 let tcp = TcpStream::connect(config.get_addr())
559 .await
560 .context("failed to connect to sql server after routing")
561 .map_err(SinkError::SqlServer)?;
562 tcp.set_nodelay(true)
563 .context(
564 "failed to setting nodelay when connecting to sql server after routing",
565 )
566 .map_err(SinkError::SqlServer)?;
567 Client::connect(config, tcp.compat_write()).await?
569 }
570 Err(e) => return Err(e.into()),
571 };
572
573 Ok(Self {
574 inner_client: client,
575 })
576 }
577}
578
579async fn query_sql_server_table_metadata(
580 sql_client: &mut SqlServerClient,
581 config: &SqlServerConfig,
582) -> Result<Vec<SqlServerColumnMetadata>> {
583 let mut sql_server_table_metadata = Vec::new();
584 let query_table_metadata_error = || {
585 SinkError::SqlServer(anyhow!(format!(
586 "SQL Server table {} metadata error",
587 config.full_object_path()
588 )))
589 };
590 static QUERY_TABLE_METADATA: &str = r#"
596SELECT
597 col.name AS ColumnName,
598 CAST(CASE WHEN pk_col.column_id IS NULL THEN 0 ELSE 1 END AS int) AS IsPk,
599 typ.name AS DataType
600FROM
601 sys.columns col
602JOIN
603 sys.types typ ON typ.user_type_id = col.user_type_id
604LEFT JOIN
605 (
606 SELECT ic.object_id, ic.column_id
607 FROM sys.indexes pk
608 JOIN sys.index_columns ic ON ic.object_id = pk.object_id AND ic.index_id = pk.index_id
609 WHERE pk.is_primary_key = 1
610 ) pk_col ON pk_col.object_id = col.object_id AND pk_col.column_id = col.column_id
611WHERE
612 col.object_id = OBJECT_ID(@P1)
613ORDER BY
614 col.column_id;"#;
615 let rows = sql_client
616 .inner_client
617 .query(QUERY_TABLE_METADATA, &[&config.full_object_path()])
618 .await?
619 .into_results()
620 .await?;
621 for row in rows.into_iter().flatten() {
622 let mut iter = row.into_iter();
623 let ColumnData::String(Some(col_name)) =
624 iter.next().ok_or_else(query_table_metadata_error)?
625 else {
626 return Err(query_table_metadata_error());
627 };
628 let ColumnData::I32(Some(col_is_pk)) =
629 iter.next().ok_or_else(query_table_metadata_error)?
630 else {
631 return Err(query_table_metadata_error());
632 };
633 let ColumnData::String(Some(data_type)) =
634 iter.next().ok_or_else(query_table_metadata_error)?
635 else {
636 return Err(query_table_metadata_error());
637 };
638 sql_server_table_metadata.push(SqlServerColumnMetadata {
639 name: normalize_sql_server_column_name(&col_name),
640 is_pk: col_is_pk != 0,
641 data_type: data_type.into_owned(),
642 });
643 }
644 Ok(sql_server_table_metadata)
645}
646
647async fn validate_sql_server_write_permission(
648 sql_client: &mut SqlServerClient,
649 config: &SqlServerConfig,
650 is_append_only: bool,
651) -> Result<()> {
652 let permission_query_error = || {
653 SinkError::SqlServer(anyhow!(format!(
654 "SQL Server table {} permission metadata error",
655 config.full_object_path()
656 )))
657 };
658 static QUERY_WRITE_PERMISSION: &str = r#"
659SELECT
660 CAST(HAS_PERMS_BY_NAME(@P1, 'OBJECT', 'INSERT') AS int) AS CanInsert,
661 CAST(HAS_PERMS_BY_NAME(@P1, 'OBJECT', 'UPDATE') AS int) AS CanUpdate,
662 CAST(HAS_PERMS_BY_NAME(@P1, 'OBJECT', 'DELETE') AS int) AS CanDelete;"#;
663 let rows = sql_client
664 .inner_client
665 .query(QUERY_WRITE_PERMISSION, &[&config.full_object_path()])
666 .await?
667 .into_results()
668 .await?;
669 let mut rows = rows.into_iter().flatten();
670 let row = rows.next().ok_or_else(permission_query_error)?;
671 let mut iter = row.into_iter();
672 let ColumnData::I32(can_insert) = iter.next().ok_or_else(permission_query_error)? else {
673 return Err(permission_query_error());
674 };
675 let ColumnData::I32(can_update) = iter.next().ok_or_else(permission_query_error)? else {
676 return Err(permission_query_error());
677 };
678 let ColumnData::I32(can_delete) = iter.next().ok_or_else(permission_query_error)? else {
679 return Err(permission_query_error());
680 };
681
682 let missing_permissions = missing_sql_server_write_permissions(
683 is_append_only,
684 permission_is_granted(can_insert),
685 permission_is_granted(can_update),
686 permission_is_granted(can_delete),
687 );
688 if missing_permissions.is_empty() {
689 return Ok(());
690 }
691
692 Err(SinkError::SqlServer(anyhow!(format!(
693 "SQL Server user {} lacks required write permission(s) {} on table {}",
694 config.user,
695 missing_permissions.join(", "),
696 config.full_object_path()
697 ))))
698}
699
700fn permission_is_granted(permission_value: Option<i32>) -> bool {
701 permission_value == Some(1)
702}
703
704fn missing_sql_server_write_permissions(
705 is_append_only: bool,
706 can_insert: bool,
707 can_update: bool,
708 can_delete: bool,
709) -> Vec<&'static str> {
710 let mut missing_permissions = vec![];
711 if !can_insert {
712 missing_permissions.push("INSERT");
713 }
714 if !is_append_only {
715 if !can_update {
716 missing_permissions.push("UPDATE");
717 }
718 if !can_delete {
719 missing_permissions.push("DELETE");
720 }
721 }
722 missing_permissions
723}
724
725async fn query_downstream_column_metadata(
726 sql_client: &mut SqlServerClient,
727 config: &SqlServerConfig,
728 schema: &Schema,
729) -> Result<Vec<SqlServerColumnMetadata>> {
730 let sql_server_table_metadata = query_sql_server_table_metadata(sql_client, config)
731 .await?
732 .into_iter()
733 .map(|metadata| (metadata.name.clone(), metadata))
734 .collect::<HashMap<_, _>>();
735 schema
736 .fields()
737 .iter()
738 .map(|col| {
739 sql_server_table_metadata
740 .get(&normalize_sql_server_column_name(&col.name))
741 .map(|metadata| SqlServerColumnMetadata {
742 name: metadata.name.clone(),
743 is_pk: metadata.is_pk,
744 data_type: metadata.data_type.clone(),
745 })
746 .ok_or_else(|| {
747 SinkError::SqlServer(anyhow!(format!(
748 "column {} not found in the downstream SQL Server table {}",
749 col.name,
750 config.full_object_path()
751 )))
752 })
753 })
754 .collect()
755}
756
757fn param_placeholder(param_id: &mut usize) -> String {
758 let placeholder = format!("@P{}", *param_id);
759 *param_id += 1;
760 placeholder
761}
762
763fn bind_params(
764 query: &mut Query<'_>,
765 row: impl Row,
766 schema: &Schema,
767 downstream_column_data_types: &[String],
768 col_indices: impl Iterator<Item = usize>,
769) -> Result<()> {
770 use risingwave_common::types::ScalarRefImpl;
771 for col_idx in col_indices {
772 match row.datum_at(col_idx) {
773 Some(data_ref) => match data_ref {
774 ScalarRefImpl::Int16(v) => query.bind(v),
775 ScalarRefImpl::Int32(v) => query.bind(v),
776 ScalarRefImpl::Int64(v) => query.bind(v),
777 ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
778 ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
779 ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
780 ScalarRefImpl::Bool(v) => query.bind(v),
781 ScalarRefImpl::Decimal(v) => match v {
782 Decimal::Normalized(d) => {
783 query.bind(decimal_to_sql(&d));
784 }
785 Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
786 tracing::warn!(
787 "Inf, -Inf, Nan in RisingWave decimal is converted into SQL Server null!"
788 );
789 query.bind(None as Option<Numeric>);
790 }
791 },
792 ScalarRefImpl::Date(v) => query.bind(v.0),
793 ScalarRefImpl::Timestamp(v) => query.bind(v.0),
794 ScalarRefImpl::Timestamptz(v) => {
795 let downstream_data_type = &downstream_column_data_types[col_idx];
796 match downstream_data_type.as_str() {
797 "bigint" | "int" | "smallint" | "tinyint" => {
798 query.bind(v.timestamp_micros());
799 }
800 "datetimeoffset" => {
801 query.bind(v.to_datetime_utc().fixed_offset());
802 }
803 "datetime" | "datetime2" | "smalldatetime" => {
804 query.bind(v.to_datetime_utc().naive_utc());
805 }
806 _ => {
807 return Err(unexpected_downstream_timestamptz_type(
808 downstream_data_type,
809 ));
810 }
811 };
812 }
813 ScalarRefImpl::Time(v) => query.bind(v.0),
814 ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
815 ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
816 ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
817 ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
818 ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
819 ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
820 ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
821 ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
822 ScalarRefImpl::Vector(_) => return Err(data_type_not_supported("Vector")),
823 },
824 None => match schema[col_idx].data_type {
825 DataType::Boolean => {
826 query.bind(None as Option<bool>);
827 }
828 DataType::Int16 => {
829 query.bind(None as Option<i16>);
830 }
831 DataType::Int32 => {
832 query.bind(None as Option<i32>);
833 }
834 DataType::Int64 => {
835 query.bind(None as Option<i64>);
836 }
837 DataType::Float32 => {
838 query.bind(None as Option<f32>);
839 }
840 DataType::Float64 => {
841 query.bind(None as Option<f64>);
842 }
843 DataType::Decimal => {
844 query.bind(None as Option<Numeric>);
845 }
846 DataType::Date => {
847 query.bind(None as Option<chrono::NaiveDate>);
848 }
849 DataType::Time => {
850 query.bind(None as Option<chrono::NaiveTime>);
851 }
852 DataType::Timestamp => {
853 query.bind(None as Option<chrono::NaiveDateTime>);
854 }
855 DataType::Timestamptz => {
856 let downstream_data_type = &downstream_column_data_types[col_idx];
857 match downstream_data_type.as_str() {
858 "bigint" | "int" | "smallint" | "tinyint" => {
859 query.bind(None as Option<i64>);
860 }
861 "datetimeoffset" => {
862 query.bind(None as Option<chrono::DateTime<chrono::FixedOffset>>);
863 }
864 "datetime" | "datetime2" | "smalldatetime" => {
865 query.bind(None as Option<chrono::NaiveDateTime>);
866 }
867 _ => {
868 return Err(unexpected_downstream_timestamptz_type(
869 downstream_data_type,
870 ));
871 }
872 };
873 }
874 DataType::Varchar => {
875 query.bind(None as Option<String>);
876 }
877 DataType::Bytea => {
878 query.bind(None as Option<Vec<u8>>);
879 }
880 DataType::Interval => return Err(data_type_not_supported("Interval")),
881 DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
882 DataType::List(_) => return Err(data_type_not_supported("List")),
883 DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
884 DataType::Serial => return Err(data_type_not_supported("Serial")),
885 DataType::Int256 => return Err(data_type_not_supported("Int256")),
886 DataType::Map(_) => return Err(data_type_not_supported("Map")),
887 DataType::Vector(_) => return Err(data_type_not_supported("Vector")),
888 },
889 };
890 }
891 Ok(())
892}
893
894fn data_type_not_supported(data_type_name: &str) -> SinkError {
895 SinkError::SqlServer(anyhow!(format!(
896 "{data_type_name} is not supported in SQL Server"
897 )))
898}
899
900fn unexpected_downstream_timestamptz_type(sql_server_data_type: &str) -> SinkError {
901 SinkError::SqlServer(anyhow!(format!(
902 "unexpected downstream SQL Server type {sql_server_data_type} for Timestamptz"
903 )))
904}
905
906fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
907 match data_type {
908 DataType::Boolean
909 | DataType::Int16
910 | DataType::Int32
911 | DataType::Int64
912 | DataType::Float32
913 | DataType::Float64
914 | DataType::Decimal
915 | DataType::Date
916 | DataType::Varchar
917 | DataType::Time
918 | DataType::Timestamp
919 | DataType::Timestamptz
920 | DataType::Bytea => Ok(()),
921 DataType::Interval => Err(data_type_not_supported("Interval")),
922 DataType::Struct(_) => Err(data_type_not_supported("Struct")),
923 DataType::List(_) => Err(data_type_not_supported("List")),
924 DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
925 DataType::Serial => Err(data_type_not_supported("Serial")),
926 DataType::Int256 => Err(data_type_not_supported("Int256")),
927 DataType::Map(_) => Err(data_type_not_supported("Map")),
928 DataType::Vector(_) => Err(data_type_not_supported("Vector")),
929 }
930}
931
932fn normalize_sql_server_column_name(column_name: &str) -> String {
933 column_name.to_lowercase()
936}
937
938fn validate_data_type_compatibility(
939 column_name: &str,
940 rw_data_type: &DataType,
941 sql_server_data_type: &str,
942) -> Result<()> {
943 if sql_server_data_type_is_compatible(rw_data_type, sql_server_data_type) {
944 return Ok(());
945 }
946
947 Err(SinkError::SqlServer(anyhow!(format!(
948 "column {} data type {:?} is incompatible with downstream SQL Server type {}",
949 column_name, rw_data_type, sql_server_data_type
950 ))))
951}
952
953fn sql_server_data_type_is_compatible(rw_data_type: &DataType, sql_server_data_type: &str) -> bool {
954 match rw_data_type {
955 DataType::Boolean => sql_server_data_type == "bit",
956 DataType::Int16 => matches!(sql_server_data_type, "smallint" | "int" | "bigint"),
957 DataType::Int32 => matches!(sql_server_data_type, "int" | "bigint"),
958 DataType::Int64 => sql_server_data_type == "bigint",
959 DataType::Float32 => matches!(sql_server_data_type, "real" | "float"),
960 DataType::Float64 => sql_server_data_type == "float",
961 DataType::Decimal => matches!(sql_server_data_type, "decimal" | "numeric"),
962 DataType::Date => sql_server_data_type == "date",
963 DataType::Varchar => matches!(
964 sql_server_data_type,
965 "char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext"
966 ),
967 DataType::Time => sql_server_data_type == "time",
968 DataType::Timestamp => {
969 matches!(
970 sql_server_data_type,
971 "datetime" | "datetime2" | "smalldatetime"
972 )
973 }
974 DataType::Timestamptz => matches!(
975 sql_server_data_type,
976 "datetimeoffset"
977 | "datetime"
978 | "datetime2"
979 | "smalldatetime"
980 | "bigint"
981 | "int"
982 | "smallint"
983 | "tinyint"
984 ),
985 DataType::Bytea => matches!(sql_server_data_type, "binary" | "varbinary" | "image"),
986 DataType::Interval
987 | DataType::Struct(_)
988 | DataType::List(_)
989 | DataType::Jsonb
990 | DataType::Serial
991 | DataType::Int256
992 | DataType::Map(_)
993 | DataType::Vector(_) => false,
994 }
995}
996
997fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
999 let unpacked = decimal.unpack();
1000
1001 let mut value = (((unpacked.hi as u128) << 64)
1002 + ((unpacked.mid as u128) << 32)
1003 + unpacked.lo as u128) as i128;
1004
1005 if decimal.is_sign_negative() {
1006 value = -value;
1007 }
1008
1009 Numeric::new_with_scale(value, decimal.scale() as u8)
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014 use super::*;
1015
1016 #[test]
1017 fn test_normalize_sql_server_column_name() {
1018 assert_eq!(normalize_sql_server_column_name("EventDate"), "eventdate");
1019 }
1020
1021 #[test]
1022 fn test_sql_server_data_type_compatibility() {
1023 assert!(sql_server_data_type_is_compatible(
1024 &DataType::Int16,
1025 "smallint"
1026 ));
1027 assert!(sql_server_data_type_is_compatible(&DataType::Int16, "int"));
1028 assert!(!sql_server_data_type_is_compatible(
1029 &DataType::Int32,
1030 "smallint"
1031 ));
1032
1033 assert!(sql_server_data_type_is_compatible(
1034 &DataType::Timestamp,
1035 "datetime2"
1036 ));
1037 assert!(sql_server_data_type_is_compatible(
1038 &DataType::Timestamptz,
1039 "datetimeoffset"
1040 ));
1041 assert!(sql_server_data_type_is_compatible(
1042 &DataType::Timestamptz,
1043 "datetime2"
1044 ));
1045 assert!(sql_server_data_type_is_compatible(
1046 &DataType::Timestamptz,
1047 "bigint"
1048 ));
1049 assert!(sql_server_data_type_is_compatible(
1050 &DataType::Timestamptz,
1051 "int"
1052 ));
1053 assert!(!sql_server_data_type_is_compatible(
1054 &DataType::Timestamp,
1055 "datetimeoffset"
1056 ));
1057
1058 assert!(sql_server_data_type_is_compatible(
1059 &DataType::Varchar,
1060 "nvarchar"
1061 ));
1062 assert!(!sql_server_data_type_is_compatible(
1063 &DataType::Varchar,
1064 "uniqueidentifier"
1065 ));
1066 }
1067
1068 #[test]
1069 fn test_missing_sql_server_write_permissions() {
1070 assert_eq!(
1071 missing_sql_server_write_permissions(true, false, false, false),
1072 vec!["INSERT"]
1073 );
1074 assert!(missing_sql_server_write_permissions(true, true, false, false).is_empty());
1075 assert_eq!(
1076 missing_sql_server_write_permissions(false, false, false, false),
1077 vec!["INSERT", "UPDATE", "DELETE"]
1078 );
1079 assert_eq!(
1080 missing_sql_server_write_permissions(false, true, false, true),
1081 vec!["UPDATE"]
1082 );
1083 assert!(missing_sql_server_write_permissions(false, true, true, true).is_empty());
1084 }
1085}