use std::collections::{BTreeMap, HashMap};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use phf::{Set, phf_set};
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Decimal};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use tiberius::numeric::Numeric;
use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use with_options::WithOptions;

use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const SQLSERVER_SINK: &str = "sqlserver";

fn default_max_batch_rows() -> usize {
    1024
}

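/// Connection and destination options for the SQL Server sink, deserialized
/// from the flat `sqlserver.*` string properties of the sink definition
/// (see the `serde(rename = ...)` attributes below). `type` selects between
/// `append-only` and `upsert` behavior.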
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct SqlServerConfig {
    #[serde(rename = "sqlserver.host")]
    pub host: String,
    #[serde(rename = "sqlserver.port")]
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    #[serde(rename = "sqlserver.user")]
    pub user: String,
    #[serde(rename = "sqlserver.password")]
    pub password: String,
    #[serde(rename = "sqlserver.database")]
    pub database: String,
    #[serde(rename = "sqlserver.schema", default = "sql_server_default_schema")]
    pub schema: String,
    #[serde(rename = "sqlserver.table")]
    pub table: String,
    #[serde(
        rename = "sqlserver.max_batch_rows",
        default = "default_max_batch_rows"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

pub fn sql_server_default_schema() -> String {
    "dbo".to_owned()
}

impl SqlServerConfig {
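    /// Builds the config from the sink's property map and validates `type`.
    ///
    /// A minimal sketch with illustrative values (the property keys come from
    /// the `serde(rename = ...)` attributes on [`SqlServerConfig`]):
    ///
    /// ```ignore
    /// let props: BTreeMap<String, String> = [
    ///     ("sqlserver.host", "localhost"),
    ///     ("sqlserver.port", "1433"),
    ///     ("sqlserver.user", "sa"),
    ///     ("sqlserver.password", "***"),
    ///     ("sqlserver.database", "mydb"),
    ///     ("sqlserver.table", "t1"),
    ///     ("type", "upsert"),
    /// ]
    /// .into_iter()
    /// .map(|(k, v)| (k.to_owned(), v.to_owned()))
    /// .collect();
    /// let config = SqlServerConfig::from_btreemap(props)?;
    /// // `sqlserver.schema` defaults to `dbo`:
    /// assert_eq!(config.full_object_path(), "[mydb].[dbo].[t1]");
    /// ```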
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

    pub fn full_object_path(&self) -> String {
        format!("[{}].[{}].[{}]", self.database, self.schema, self.table)
    }
}

impl EnforceSecret for SqlServerConfig {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "sqlserver.password"
    };
}

#[derive(Debug)]
pub struct SqlServerSink {
    pub config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for SqlServerSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SqlServerConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl SqlServerSink {
    pub fn new(
        mut config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
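        // tiberius binds one parameter per column per row, and SQL Server rejects
        // queries with too many parameters, so cap the batch size such that
        // `max_batch_rows * params_per_op` stays within `TIBERIUS_PARAM_MAX`.
        // For example, a 50-column table allows at most 2000 / 50 = 40 rows per batch.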
        const TIBERIUS_PARAM_MAX: usize = 2000;

        let params_per_op = schema.fields().len();
        let tiberius_max_batch_rows = if params_per_op == 0 {
            config.max_batch_rows
        } else {
            ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
        };
        if tiberius_max_batch_rows == 0 {
            return Err(SinkError::SqlServer(anyhow!(format!(
                "too many columns: {}",
                params_per_op
            ))));
        }
        config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for SqlServerSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SqlServerConfig::from_btreemap(param.properties)?;
        SqlServerSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for SqlServerSink {
    type LogSinker = LogSinkerOf<SqlServerSinkWriter>;

    const SINK_NAME: &'static str = SQLSERVER_SINK;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SqlServerSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;

        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
            )));
        }

        for f in self.schema.fields() {
            check_data_type_compatibility(&f.data_type)?;
        }

        let mut sql_server_table_metadata = HashMap::new();
        let mut sql_client = SqlServerClient::new(&self.config).await?;
        let query_table_metadata_error = || {
            SinkError::SqlServer(anyhow!(format!(
                "SQL Server table {} metadata error",
                self.config.full_object_path()
            )))
        };
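        // One result row per column of the target table; `PkIndex` is non-NULL
        // exactly when the column participates in the primary key index.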
        static QUERY_TABLE_METADATA: &str = r#"
SELECT
    col.name AS ColumnName,
    pk.index_id AS PkIndex
FROM
    sys.columns col
LEFT JOIN
    sys.index_columns ic ON ic.object_id = col.object_id AND ic.column_id = col.column_id
LEFT JOIN
    sys.indexes pk ON pk.object_id = col.object_id AND pk.index_id = ic.index_id AND pk.is_primary_key = 1
WHERE
    col.object_id = OBJECT_ID(@P1)
ORDER BY
    col.column_id;"#;
        let rows = sql_client
            .inner_client
            .query(QUERY_TABLE_METADATA, &[&self.config.full_object_path()])
            .await?
            .into_results()
            .await?;
        for row in rows.into_iter().flatten() {
            let mut iter = row.into_iter();
            let ColumnData::String(Some(col_name)) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            let ColumnData::I32(col_pk_index) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            sql_server_table_metadata.insert(col_name.into_owned(), col_pk_index.is_some());
        }

        for (idx, col) in self.schema.fields().iter().enumerate() {
            let rw_is_pk = self.pk_indices.contains(&idx);
            match sql_server_table_metadata.get(&col.name) {
                None => {
                    return Err(SinkError::SqlServer(anyhow!(format!(
                        "column {} not found in the downstream SQL Server table {}",
                        col.name,
                        self.config.full_object_path()
                    ))));
                }
                Some(sql_server_is_pk) => {
                    if self.is_append_only {
                        continue;
                    }
                    if rw_is_pk && !*sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is specified in primary_key but is not a primary key column of the downstream SQL Server table {}",
                            col.name,
                            self.config.full_object_path(),
                        ))));
                    }
                    if !rw_is_pk && *sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is not specified in primary_key but is a primary key column of the downstream SQL Server table {}",
                            col.name,
                            self.config.full_object_path(),
                        ))));
                    }
                }
            }
        }

        if !self.is_append_only {
            let sql_server_pk_count = sql_server_table_metadata
                .values()
                .filter(|is_pk| **is_pk)
                .count();
            if sql_server_pk_count != self.pk_indices.len() {
                return Err(SinkError::SqlServer(anyhow!(format!(
                    "primary key column count does not match between the RisingWave sink ({}) and the SQL Server table {} ({})",
                    self.pk_indices.len(),
                    self.config.full_object_path(),
                    sql_server_pk_count,
                ))));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(SqlServerSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
    }
}

enum SqlOp {
    Insert(OwnedRow),
    Merge(OwnedRow),
    Delete(OwnedRow),
}

pub struct SqlServerSinkWriter {
    config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    sql_client: SqlServerClient,
    ops: Vec<SqlOp>,
}

impl SqlServerSinkWriter {
    async fn new(
        config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let sql_client = SqlServerClient::new(&config).await?;
        let writer = Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            sql_client,
            ops: vec![],
        };
        Ok(writer)
    }

    async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Delete(row.into_owned_row()));
        Ok(())
    }

    async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Merge(row.into_owned_row()));
        Ok(())
    }

    async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Insert(row.into_owned_row()));
        Ok(())
    }

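    /// Flushes all buffered ops as a single multi-statement query, with one
    /// `INSERT`, `MERGE`, or `DELETE` statement per op and `@Pn` placeholders
    /// numbered consecutively across statements. For a target table `t(a, b)`
    /// with primary key `a` (illustrative), one insert followed by one delete
    /// generates roughly:
    ///
    /// ```ignore
    /// INSERT INTO [mydb].[dbo].[t] ([a],[b]) VALUES (@P1,@P2);
    /// DELETE FROM [mydb].[dbo].[t] WHERE [a]=@P3;
    /// ```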
    async fn flush(&mut self) -> Result<()> {
        use std::fmt::Write;
        if self.ops.is_empty() {
            return Ok(());
        }
        let mut query_str = String::new();
        let col_num = self.schema.fields.len();
        let mut next_param_id = 1;
        let non_pk_col_indices = (0..col_num)
            .filter(|idx| !self.pk_indices.contains(idx))
            .collect::<Vec<usize>>();
        let all_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let all_source_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[SOURCE].[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let pk_match = self
            .pk_indices
            .iter()
            .map(|idx| {
                format!(
                    "[SOURCE].[{}]=[TARGET].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(" AND ");
        let param_placeholders = |param_id: &mut usize| {
            let params = (*param_id..(*param_id + col_num))
                .map(|i| format!("@P{}", i))
                .collect::<Vec<_>>()
                .join(",");
            *param_id += col_num;
            params
        };
        let set_all_source_col = non_pk_col_indices
            .iter()
            .map(|idx| {
                format!(
                    "[{}]=[SOURCE].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(",");

        for op in &self.ops {
            match op {
                SqlOp::Insert(_) => {
                    write!(
                        &mut query_str,
                        "INSERT INTO {} ({}) VALUES ({});",
                        self.config.full_object_path(),
                        all_col_names,
                        param_placeholders(&mut next_param_id),
                    )
                    .unwrap();
                }
                SqlOp::Merge(_) => {
                    write!(
                        &mut query_str,
                        r#"MERGE {} AS [TARGET]
                        USING (VALUES ({})) AS [SOURCE] ({})
                        ON {}
                        WHEN MATCHED THEN UPDATE SET {}
                        WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
                        self.config.full_object_path(),
                        param_placeholders(&mut next_param_id),
                        all_col_names,
                        pk_match,
                        set_all_source_col,
                        all_col_names,
                        all_source_col_names,
                    )
                    .unwrap();
                }
                SqlOp::Delete(_) => {
                    write!(
                        &mut query_str,
                        r#"DELETE FROM {} WHERE {};"#,
                        self.config.full_object_path(),
                        self.pk_indices
                            .iter()
                            .map(|idx| {
                                let condition =
                                    format!("[{}]=@P{}", self.schema[*idx].name, next_param_id);
                                next_param_id += 1;
                                condition
                            })
                            .collect::<Vec<_>>()
                            .join(" AND "),
                    )
                    .unwrap();
                }
            }
        }
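        // Bind values in the same order the statements were generated above, so
        // that each value lines up with its `@Pn` placeholder.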
        let mut query = Query::new(query_str);
        for op in self.ops.drain(..) {
            match op {
                SqlOp::Insert(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Merge(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Delete(row) => {
                    bind_params(
                        &mut query,
                        row,
                        &self.schema,
                        self.pk_indices.iter().copied(),
                    )?;
                }
            }
        }
        query.execute(&mut self.sql_client.inner_client).await?;
        Ok(())
    }
}

#[async_trait]
impl SinkWriter for SqlServerSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    if self.is_append_only {
                        self.insert_one(row).await?;
                    } else {
                        self.upsert_one(row).await?;
                    }
                }
                Op::UpdateInsert => {
                    debug_assert!(!self.is_append_only);
                    self.upsert_one(row).await?;
                }
                Op::Delete => {
                    debug_assert!(!self.is_append_only);
                    self.delete_one(row).await?;
                }
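                // `UpdateDelete` needs no statement of its own: the paired
                // `UpdateInsert` is applied as a MERGE keyed on the primary key,
                // which overwrites the old row in place.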
                Op::UpdateDelete => {}
            }
        }
        Ok(())
    }

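    /// Flushes buffered ops on checkpoint barriers, so every row written before
    /// a checkpoint is durable in SQL Server before the checkpoint completes.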
    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
        if is_checkpoint {
            self.flush().await?;
        }
        Ok(())
    }
}

#[derive(Debug)]
pub struct SqlServerClient {
    pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
}

impl SqlServerClient {
    async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
        let mut config = Config::new();
        config.host(&msconfig.host);
        config.port(msconfig.port);
        config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
        config.database(&msconfig.database);
        config.trust_cert();
        Self::new_with_config(config).await
    }

    pub async fn new_with_config(mut config: Config) -> Result<Self> {
        let tcp = TcpStream::connect(config.get_addr())
            .await
            .context("failed to connect to sql server")
            .map_err(SinkError::SqlServer)?;
        tcp.set_nodelay(true)
            .context("failed to set nodelay when connecting to sql server")
            .map_err(SinkError::SqlServer)?;

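        // SQL Server may answer the handshake with a routing redirect (e.g. an
        // Azure SQL gateway pointing at the actual host); reconnect once to the
        // redirected address.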
        let client = match Client::connect(config.clone(), tcp.compat_write()).await {
            Ok(client) => client,
            Err(tiberius::error::Error::Routing { host, port }) => {
                config.host(&host);
                config.port(port);
                let tcp = TcpStream::connect(config.get_addr())
                    .await
                    .context("failed to connect to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                tcp.set_nodelay(true)
                    .context(
                        "failed to set nodelay when connecting to sql server after routing",
                    )
                    .map_err(SinkError::SqlServer)?;
                Client::connect(config, tcp.compat_write()).await?
            }
            Err(e) => return Err(e.into()),
        };

        Ok(Self {
            inner_client: client,
        })
    }
}
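/// Binds one parameter per selected column, in `col_indices` order. NULLs are
/// bound through a typed `Option<T>` matching the column's `DataType`, since
/// tiberius derives the SQL parameter type from the Rust value type. Note that
/// `Timestamptz` is bound as microseconds since the epoch (an `i64`).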
fn bind_params(
    query: &mut Query<'_>,
    row: impl Row,
    schema: &Schema,
    col_indices: impl Iterator<Item = usize>,
) -> Result<()> {
    use risingwave_common::types::ScalarRefImpl;
    for col_idx in col_indices {
        match row.datum_at(col_idx) {
            Some(data_ref) => match data_ref {
                ScalarRefImpl::Int16(v) => query.bind(v),
                ScalarRefImpl::Int32(v) => query.bind(v),
                ScalarRefImpl::Int64(v) => query.bind(v),
                ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
                ScalarRefImpl::Bool(v) => query.bind(v),
                ScalarRefImpl::Decimal(v) => match v {
                    Decimal::Normalized(d) => {
                        query.bind(decimal_to_sql(&d));
                    }
                    Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
                        tracing::warn!(
                            "Inf, -Inf and NaN in a RisingWave decimal are converted to SQL Server NULL"
                        );
                        query.bind(None as Option<Numeric>);
                    }
                },
                ScalarRefImpl::Date(v) => query.bind(v.0),
                ScalarRefImpl::Timestamp(v) => query.bind(v.0),
                ScalarRefImpl::Timestamptz(v) => query.bind(v.timestamp_micros()),
                ScalarRefImpl::Time(v) => query.bind(v.0),
                ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
                ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
                ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
                ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
                ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
                ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
                ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
                ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
                ScalarRefImpl::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
            },
            None => match schema[col_idx].data_type {
                DataType::Boolean => {
                    query.bind(None as Option<bool>);
                }
                DataType::Int16 => {
                    query.bind(None as Option<i16>);
                }
                DataType::Int32 => {
                    query.bind(None as Option<i32>);
                }
                DataType::Int64 => {
                    query.bind(None as Option<i64>);
                }
                DataType::Float32 => {
                    query.bind(None as Option<f32>);
                }
                DataType::Float64 => {
                    query.bind(None as Option<f64>);
                }
                DataType::Decimal => {
                    query.bind(None as Option<Numeric>);
                }
                DataType::Date => {
                    query.bind(None as Option<chrono::NaiveDate>);
                }
                DataType::Time => {
                    query.bind(None as Option<chrono::NaiveTime>);
                }
                DataType::Timestamp => {
                    query.bind(None as Option<chrono::NaiveDateTime>);
                }
                DataType::Timestamptz => {
                    query.bind(None as Option<i64>);
                }
                DataType::Varchar => {
                    query.bind(None as Option<String>);
                }
                DataType::Bytea => {
                    query.bind(None as Option<Vec<u8>>);
                }
                DataType::Interval => return Err(data_type_not_supported("Interval")),
                DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
                DataType::List(_) => return Err(data_type_not_supported("List")),
                DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
                DataType::Serial => return Err(data_type_not_supported("Serial")),
                DataType::Int256 => return Err(data_type_not_supported("Int256")),
                DataType::Map(_) => return Err(data_type_not_supported("Map")),
                DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
            },
        };
    }
    Ok(())
}

fn data_type_not_supported(data_type_name: &str) -> SinkError {
    SinkError::SqlServer(anyhow!(format!(
        "{data_type_name} is not supported in SQL Server"
    )))
}

fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
    match data_type {
        DataType::Boolean
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal
        | DataType::Date
        | DataType::Varchar
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Bytea => Ok(()),
        DataType::Interval => Err(data_type_not_supported("Interval")),
        DataType::Struct(_) => Err(data_type_not_supported("Struct")),
        DataType::List(_) => Err(data_type_not_supported("List")),
        DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
        DataType::Serial => Err(data_type_not_supported("Serial")),
        DataType::Int256 => Err(data_type_not_supported("Int256")),
        DataType::Map(_) => Err(data_type_not_supported("Map")),
        DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
    }
}
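/// Converts a `rust_decimal::Decimal` into a tiberius `Numeric`: the 96-bit
/// mantissa (`lo`/`mid`/`hi`) is recombined into a signed `i128` and paired
/// with the decimal's scale. For example, 12.34 unpacks to mantissa 1234 with
/// scale 2.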
fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
    let unpacked = decimal.unpack();

    let mut value = (((unpacked.hi as u128) << 64)
        + ((unpacked.mid as u128) << 32)
        + unpacked.lo as u128) as i128;

    if decimal.is_sign_negative() {
        value = -value;
    }

    Numeric::new_with_scale(value, decimal.scale() as u8)
}