use std::collections::{BTreeMap, HashMap};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use phf::{Set, phf_set};
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Decimal};
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use tiberius::numeric::Numeric;
use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use with_options::WithOptions;

use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const SQLSERVER_SINK: &str = "sqlserver";

fn default_max_batch_rows() -> usize {
    1024
}

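/// User-facing options of the SQL Server sink, deserialized from the string
/// key-value pairs of the `WITH` clause; numeric fields therefore arrive as
/// strings and are parsed via `DisplayFromStr`.
///
/// An illustrative definition (names and values are placeholders, not taken
/// from this file) could look like:
///
/// ```sql
/// CREATE SINK my_sink FROM my_mv WITH (
///     connector = 'sqlserver',
///     type = 'upsert',
///     primary_key = 'id',
///     sqlserver.host = 'localhost',
///     sqlserver.port = '1433',
///     sqlserver.user = 'sa',
///     sqlserver.password = '...',
///     sqlserver.database = 'mydb',
///     sqlserver.table = 'my_table'
/// );
/// ```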
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct SqlServerConfig {
    #[serde(rename = "sqlserver.host")]
    pub host: String,
    #[serde(rename = "sqlserver.port")]
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    #[serde(rename = "sqlserver.user")]
    pub user: String,
    #[serde(rename = "sqlserver.password")]
    pub password: String,
    #[serde(rename = "sqlserver.database")]
    pub database: String,
    #[serde(rename = "sqlserver.schema", default = "sql_server_default_schema")]
    pub schema: String,
    #[serde(rename = "sqlserver.table")]
    pub table: String,
    #[serde(
        rename = "sqlserver.max_batch_rows",
        default = "default_max_batch_rows"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

pub fn sql_server_default_schema() -> String {
    "dbo".to_owned()
}

impl SqlServerConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

    pub fn full_object_path(&self) -> String {
        format!("[{}].[{}].[{}]", self.database, self.schema, self.table)
    }
}

impl EnforceSecret for SqlServerConfig {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "sqlserver.password"
    };
}

#[derive(Debug)]
pub struct SqlServerSink {
    pub config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for SqlServerSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SqlServerConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl SqlServerSink {
    pub fn new(
        mut config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
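        // Each buffered row binds one parameter per column, and SQL Server
        // rejects requests with more than 2100 bound parameters;
        // TIBERIUS_PARAM_MAX stays at 2000 for headroom, and the batch size
        // is capped so that `rows * columns` never exceeds it.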
        const TIBERIUS_PARAM_MAX: usize = 2000;
        let params_per_op = schema.fields().len();
        let tiberius_max_batch_rows = if params_per_op == 0 {
            config.max_batch_rows
        } else {
            ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
        };
        if tiberius_max_batch_rows == 0 {
            return Err(SinkError::SqlServer(anyhow!(
                "too many columns: {}",
                params_per_op
            )));
        }
        config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for SqlServerSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let pk_indices = param.downstream_pk_or_empty();
        let config = SqlServerConfig::from_btreemap(param.properties)?;
        SqlServerSink::new(config, schema, pk_indices, param.sink_type.is_append_only())
    }
}

impl Sink for SqlServerSink {
    type LogSinker = LogSinkerOf<SqlServerSinkWriter>;

    const SINK_NAME: &'static str = SQLSERVER_SINK;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SqlServerSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;

        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
            )));
        }

        for f in self.schema.fields() {
            check_data_type_compatibility(&f.data_type)?;
        }

        let mut sql_server_table_metadata = HashMap::new();
        let mut sql_client = SqlServerClient::new(&self.config).await?;
        let query_table_metadata_error = || {
            SinkError::SqlServer(anyhow!(
                "SQL Server table {} metadata error",
                self.config.full_object_path()
            ))
        };
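        // Fetch every column of the target table together with whether it is
        // part of the primary key, so the sink schema can be validated below.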
        static QUERY_TABLE_METADATA: &str = r#"
SELECT
    col.name AS ColumnName,
    pk.index_id AS PkIndex
FROM
    sys.columns col
LEFT JOIN
    sys.index_columns ic ON ic.object_id = col.object_id AND ic.column_id = col.column_id
LEFT JOIN
    sys.indexes pk ON pk.object_id = col.object_id AND pk.index_id = ic.index_id AND pk.is_primary_key = 1
WHERE
    col.object_id = OBJECT_ID(@P1)
ORDER BY
    col.column_id;"#;
        let rows = sql_client
            .inner_client
            .query(QUERY_TABLE_METADATA, &[&self.config.full_object_path()])
            .await?
            .into_results()
            .await?;
        for row in rows.into_iter().flatten() {
            let mut iter = row.into_iter();
            let ColumnData::String(Some(col_name)) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            let ColumnData::I32(col_pk_index) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            sql_server_table_metadata.insert(col_name.into_owned(), col_pk_index.is_some());
        }

        for (idx, col) in self.schema.fields().iter().enumerate() {
            let rw_is_pk = self.pk_indices.contains(&idx);
            match sql_server_table_metadata.get(&col.name) {
                None => {
                    return Err(SinkError::SqlServer(anyhow!(
                        "column {} not found in the downstream SQL Server table {}",
                        col.name,
                        self.config.full_object_path()
                    )));
                }
                Some(sql_server_is_pk) => {
                    if self.is_append_only {
                        continue;
                    }
                    if rw_is_pk && !*sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(
                            "column {} is listed in `primary_key` but is not a primary key column of the downstream SQL Server table {}",
                            col.name,
                            self.config.full_object_path(),
                        )));
                    }
                    if !rw_is_pk && *sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(
                            "column {} is a primary key column of the downstream SQL Server table {} but is missing from `primary_key`",
                            col.name,
                            self.config.full_object_path(),
                        )));
                    }
                }
            }
        }

        if !self.is_append_only {
            let sql_server_pk_count = sql_server_table_metadata
                .values()
                .filter(|is_pk| **is_pk)
                .count();
            if sql_server_pk_count != self.pk_indices.len() {
                return Err(SinkError::SqlServer(anyhow!(
                    "primary key column count mismatch between the RisingWave sink ({}) and the SQL Server table {} ({})",
                    self.pk_indices.len(),
                    self.config.full_object_path(),
                    sql_server_pk_count,
                )));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(SqlServerSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
    }
}

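/// A buffered row-level operation; the writer accumulates these and flushes
/// them as one multi-statement batch.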
enum SqlOp {
    Insert(OwnedRow),
    Merge(OwnedRow),
    Delete(OwnedRow),
}

pub struct SqlServerSinkWriter {
    config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    sql_client: SqlServerClient,
    ops: Vec<SqlOp>,
}

impl SqlServerSinkWriter {
    async fn new(
        config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let sql_client = SqlServerClient::new(&config).await?;
        let writer = Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            sql_client,
            ops: vec![],
        };
        Ok(writer)
    }

    async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Delete(row.into_owned_row()));
        Ok(())
    }

    async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Merge(row.into_owned_row()));
        Ok(())
    }

    async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Insert(row.into_owned_row()));
        Ok(())
    }

    async fn flush(&mut self) -> Result<()> {
        use std::fmt::Write;
        if self.ops.is_empty() {
            return Ok(());
        }
        let mut query_str = String::new();
        let col_num = self.schema.fields.len();
        let mut next_param_id = 1;
        let non_pk_col_indices = (0..col_num)
            .filter(|idx| !self.pk_indices.contains(idx))
            .collect::<Vec<usize>>();
        let all_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let all_source_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[SOURCE].[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let pk_match = self
            .pk_indices
            .iter()
            .map(|idx| {
                format!(
                    "[SOURCE].[{}]=[TARGET].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(" AND ");
        let param_placeholders = |param_id: &mut usize| {
            let params = (*param_id..(*param_id + col_num))
                .map(|i| format!("@P{}", i))
                .collect::<Vec<_>>()
                .join(",");
            *param_id += col_num;
            params
        };
        let set_all_source_col = non_pk_col_indices
            .iter()
            .map(|idx| {
                format!(
                    "[{}]=[SOURCE].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(",");
        for op in &self.ops {
            match op {
                SqlOp::Insert(_) => {
                    write!(
                        &mut query_str,
                        "INSERT INTO {} ({}) VALUES ({});",
                        self.config.full_object_path(),
                        all_col_names,
                        param_placeholders(&mut next_param_id),
                    )
                    .unwrap();
                }
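                // An upsert is written as a MERGE: match on the primary key,
                // update the non-PK columns when matched, insert otherwise.
                // HOLDLOCK holds the key lock between the match and the
                // update/insert so concurrent MERGEs on the same key cannot
                // interleave.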
                SqlOp::Merge(_) => {
                    write!(
                        &mut query_str,
                        r#"MERGE {} WITH (HOLDLOCK) AS [TARGET]
                        USING (VALUES ({})) AS [SOURCE] ({})
                        ON {}
                        WHEN MATCHED THEN UPDATE SET {}
                        WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
                        self.config.full_object_path(),
                        param_placeholders(&mut next_param_id),
                        all_col_names,
                        pk_match,
                        set_all_source_col,
                        all_col_names,
                        all_source_col_names,
                    )
                    .unwrap();
                }
                SqlOp::Delete(_) => {
                    write!(
                        &mut query_str,
                        r#"DELETE FROM {} WHERE {};"#,
                        self.config.full_object_path(),
                        self.pk_indices
                            .iter()
                            .map(|idx| {
                                let condition =
                                    format!("[{}]=@P{}", self.schema[*idx].name, next_param_id);
                                next_param_id += 1;
                                condition
                            })
                            .collect::<Vec<_>>()
                            .join(" AND "),
                    )
                    .unwrap();
                }
            }
        }

        let mut query = Query::new(query_str);
        for op in self.ops.drain(..) {
            match op {
                SqlOp::Insert(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Merge(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Delete(row) => {
                    bind_params(
                        &mut query,
                        row,
                        &self.schema,
                        self.pk_indices.iter().copied(),
                    )?;
                }
            }
        }
        query.execute(&mut self.sql_client.inner_client).await?;
        Ok(())
    }
}

#[async_trait]
impl SinkWriter for SqlServerSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    if self.is_append_only {
                        self.insert_one(row).await?;
                    } else {
                        self.upsert_one(row).await?;
                    }
                }
                Op::UpdateInsert => {
                    debug_assert!(!self.is_append_only);
                    self.upsert_one(row).await?;
                }
                Op::Delete => {
                    debug_assert!(!self.is_append_only);
                    self.delete_one(row).await?;
                }
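                // `UpdateDelete` needs no statement of its own: the paired
                // `UpdateInsert` that follows is handled as a MERGE upsert.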
                Op::UpdateDelete => {}
            }
        }
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
        if is_checkpoint {
            self.flush().await?;
        }
        Ok(())
    }
}

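/// A thin wrapper around a connected tiberius [`Client`].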
#[derive(Debug)]
pub struct SqlServerClient {
    pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
}

impl SqlServerClient {
    async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
        let mut config = Config::new();
        config.host(&msconfig.host);
        config.port(msconfig.port);
        config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
        config.database(&msconfig.database);
        config.trust_cert();
        Self::new_with_config(config).await
    }

    pub async fn new_with_config(mut config: Config) -> Result<Self> {
        let tcp = TcpStream::connect(config.get_addr())
            .await
            .context("failed to connect to sql server")
            .map_err(SinkError::SqlServer)?;
        tcp.set_nodelay(true)
            .context("failed to set nodelay when connecting to sql server")
            .map_err(SinkError::SqlServer)?;

        let client = match Client::connect(config.clone(), tcp.compat_write()).await {
            Ok(client) => client,
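            // The server may answer the login with a routing redirect (for
            // example on Azure SQL); reconnect to the host and port it names.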
            Err(tiberius::error::Error::Routing { host, port }) => {
                config.host(&host);
                config.port(port);
                let tcp = TcpStream::connect(config.get_addr())
                    .await
                    .context("failed to connect to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                tcp.set_nodelay(true)
                    .context("failed to set nodelay when connecting to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                Client::connect(config, tcp.compat_write()).await?
            }
            Err(e) => return Err(e.into()),
        };

        Ok(Self {
            inner_client: client,
        })
    }
}

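/// Binds the datums of `row` at `col_indices` onto `query` in iteration order.
/// `NULL`s are bound through a concretely typed `Option` so tiberius can still
/// infer the SQL type of the parameter.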
fn bind_params(
    query: &mut Query<'_>,
    row: impl Row,
    schema: &Schema,
    col_indices: impl Iterator<Item = usize>,
) -> Result<()> {
    use risingwave_common::types::ScalarRefImpl;
    for col_idx in col_indices {
        match row.datum_at(col_idx) {
            Some(data_ref) => match data_ref {
                ScalarRefImpl::Int16(v) => query.bind(v),
                ScalarRefImpl::Int32(v) => query.bind(v),
                ScalarRefImpl::Int64(v) => query.bind(v),
                ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
                ScalarRefImpl::Bool(v) => query.bind(v),
                ScalarRefImpl::Decimal(v) => match v {
                    Decimal::Normalized(d) => {
                        query.bind(decimal_to_sql(&d));
                    }
                    Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
                        tracing::warn!(
                            "Inf, -Inf, NaN in RisingWave decimal is converted into SQL Server null!"
                        );
                        query.bind(None as Option<Numeric>);
                    }
                },
                ScalarRefImpl::Date(v) => query.bind(v.0),
                ScalarRefImpl::Timestamp(v) => query.bind(v.0),
                ScalarRefImpl::Timestamptz(v) => query.bind(v.timestamp_micros()),
                ScalarRefImpl::Time(v) => query.bind(v.0),
                ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
                ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
                ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
                ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
                ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
                ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
                ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
                ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
                ScalarRefImpl::Vector(_) => return Err(data_type_not_supported("Vector")),
            },
            None => match schema[col_idx].data_type {
                DataType::Boolean => {
                    query.bind(None as Option<bool>);
                }
                DataType::Int16 => {
                    query.bind(None as Option<i16>);
                }
                DataType::Int32 => {
                    query.bind(None as Option<i32>);
                }
                DataType::Int64 => {
                    query.bind(None as Option<i64>);
                }
                DataType::Float32 => {
                    query.bind(None as Option<f32>);
                }
                DataType::Float64 => {
                    query.bind(None as Option<f64>);
                }
                DataType::Decimal => {
                    query.bind(None as Option<Numeric>);
                }
                DataType::Date => {
                    query.bind(None as Option<chrono::NaiveDate>);
                }
                DataType::Time => {
                    query.bind(None as Option<chrono::NaiveTime>);
                }
                DataType::Timestamp => {
                    query.bind(None as Option<chrono::NaiveDateTime>);
                }
                DataType::Timestamptz => {
                    query.bind(None as Option<i64>);
                }
                DataType::Varchar => {
                    query.bind(None as Option<String>);
                }
                DataType::Bytea => {
                    query.bind(None as Option<Vec<u8>>);
                }
                DataType::Interval => return Err(data_type_not_supported("Interval")),
                DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
                DataType::List(_) => return Err(data_type_not_supported("List")),
                DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
                DataType::Serial => return Err(data_type_not_supported("Serial")),
                DataType::Int256 => return Err(data_type_not_supported("Int256")),
                DataType::Map(_) => return Err(data_type_not_supported("Map")),
                DataType::Vector(_) => return Err(data_type_not_supported("Vector")),
            },
        };
    }
    Ok(())
}

fn data_type_not_supported(data_type_name: &str) -> SinkError {
    SinkError::SqlServer(anyhow!("{data_type_name} is not supported in SQL Server"))
}

fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
    match data_type {
        DataType::Boolean
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal
        | DataType::Date
        | DataType::Varchar
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Bytea => Ok(()),
        DataType::Interval => Err(data_type_not_supported("Interval")),
        DataType::Struct(_) => Err(data_type_not_supported("Struct")),
        DataType::List(_) => Err(data_type_not_supported("List")),
        DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
        DataType::Serial => Err(data_type_not_supported("Serial")),
        DataType::Int256 => Err(data_type_not_supported("Int256")),
        DataType::Map(_) => Err(data_type_not_supported("Map")),
        DataType::Vector(_) => Err(data_type_not_supported("Vector")),
    }
}

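/// Converts a `rust_decimal::Decimal` into a tiberius [`Numeric`] by
/// reassembling the 96-bit mantissa from its three 32-bit limbs, then
/// reapplying the sign and decimal scale.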
fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
    let unpacked = decimal.unpack();

    let mut value = (((unpacked.hi as u128) << 64)
        + ((unpacked.mid as u128) << 32)
        + unpacked.lo as u128) as i128;

    if decimal.is_sign_negative() {
        value = -value;
    }

    Numeric::new_with_scale(value, decimal.scale() as u8)
}