use std::collections::{BTreeMap, HashMap};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use phf::{Set, phf_set};
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Decimal};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use tiberius::numeric::Numeric;
use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use with_options::WithOptions;

use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
};
use crate::enforce_secret::EnforceSecret;
use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const SQLSERVER_SINK: &str = "sqlserver";

fn default_max_batch_rows() -> usize {
    1024
}

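/// Connector options for the SQL Server sink, deserialized from the sink's
/// `WITH` clause. Keys follow the `sqlserver.`-prefixed serde renames below
/// (e.g. `sqlserver.host`); `type` must be `append-only` or `upsert`.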
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct SqlServerConfig {
    #[serde(rename = "sqlserver.host")]
    pub host: String,
    #[serde(rename = "sqlserver.port")]
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    #[serde(rename = "sqlserver.user")]
    pub user: String,
    #[serde(rename = "sqlserver.password")]
    pub password: String,
    #[serde(rename = "sqlserver.database")]
    pub database: String,
    #[serde(rename = "sqlserver.schema", default = "sql_server_default_schema")]
    pub schema: String,
    #[serde(rename = "sqlserver.table")]
    pub table: String,
    #[serde(
        rename = "sqlserver.max_batch_rows",
        default = "default_max_batch_rows"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

pub fn sql_server_default_schema() -> String {
    "dbo".to_owned()
}

impl SqlServerConfig {
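    /// Deserializes raw `WITH`-clause properties into a typed config,
    /// rejecting any `type` other than `append-only` or `upsert`. A minimal
    /// sketch of the round trip, with illustrative values:
    ///
    /// ```ignore
    /// let props: BTreeMap<String, String> = [
    ///     ("sqlserver.host", "localhost"),
    ///     ("sqlserver.port", "1433"),
    ///     ("sqlserver.user", "sa"),
    ///     ("sqlserver.password", "secret"),
    ///     ("sqlserver.database", "mydb"),
    ///     ("sqlserver.table", "my_table"),
    ///     ("type", "upsert"),
    /// ]
    /// .into_iter()
    /// .map(|(k, v)| (k.to_owned(), v.to_owned()))
    /// .collect();
    /// let config = SqlServerConfig::from_btreemap(props)?;
    /// assert_eq!(config.schema, "dbo"); // `sqlserver.schema` defaults to "dbo"
    /// ```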
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }

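    /// Renders the bracket-quoted three-part object name used in generated
    /// T-SQL, e.g. `[mydb].[dbo].[my_table]`.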
    pub fn full_object_path(&self) -> String {
        format!("[{}].[{}].[{}]", self.database, self.schema, self.table)
    }
}

impl EnforceSecret for SqlServerConfig {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "sqlserver.password"
    };
}

#[derive(Debug)]
pub struct SqlServerSink {
    pub config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for SqlServerSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            SqlServerConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl SqlServerSink {
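    /// Builds the sink, clamping `max_batch_rows` so a single batched query
    /// never exceeds `TIBERIUS_PARAM_MAX` (2000) bind parameters. A worked
    /// example, assuming a 30-column schema: floor(2000 / 30) = 66, so a
    /// configured `max_batch_rows` of 1024 is clamped to min(1024, 66) = 66.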
    pub fn new(
        mut config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        const TIBERIUS_PARAM_MAX: usize = 2000;
        let params_per_op = schema.fields().len();
        let tiberius_max_batch_rows = if params_per_op == 0 {
            config.max_batch_rows
        } else {
            ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
        };
        if tiberius_max_batch_rows == 0 {
            return Err(SinkError::SqlServer(anyhow!(format!(
                "too many columns: {}",
                params_per_op
            ))));
        }
        config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for SqlServerSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SqlServerConfig::from_btreemap(param.properties)?;
        SqlServerSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for SqlServerSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = LogSinkerOf<SqlServerSinkWriter>;

    const SINK_NAME: &'static str = SQLSERVER_SINK;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SqlServerSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;

        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
            )));
        }

        for f in self.schema.fields() {
            check_data_type_compatibility(&f.data_type)?;
        }

        let mut sql_server_table_metadata = HashMap::new();
        let mut sql_client = SqlServerClient::new(&self.config).await?;
        let query_table_metadata_error = || {
            SinkError::SqlServer(anyhow!(format!(
                "SQL Server table {} metadata error",
                self.config.full_object_path()
            )))
        };
        static QUERY_TABLE_METADATA: &str = r#"
SELECT
    col.name AS ColumnName,
    pk.index_id AS PkIndex
FROM
    sys.columns col
LEFT JOIN
    sys.index_columns ic ON ic.object_id = col.object_id AND ic.column_id = col.column_id
LEFT JOIN
    sys.indexes pk ON pk.object_id = col.object_id AND pk.index_id = ic.index_id AND pk.is_primary_key = 1
WHERE
    col.object_id = OBJECT_ID(@P1)
ORDER BY
    col.column_id;"#;
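        // Illustrative result shape for a hypothetical table
        // `t(id INT PRIMARY KEY, v INT)`: one row per column, with a non-NULL
        // `PkIndex` only on primary-key columns, e.g. ("id", Some(1)) and
        // ("v", None).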
        let rows = sql_client
            .inner_client
            .query(QUERY_TABLE_METADATA, &[&self.config.full_object_path()])
            .await?
            .into_results()
            .await?;
        for row in rows.into_iter().flatten() {
            let mut iter = row.into_iter();
            let ColumnData::String(Some(col_name)) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            let ColumnData::I32(col_pk_index) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            sql_server_table_metadata.insert(col_name.into_owned(), col_pk_index.is_some());
        }

        for (idx, col) in self.schema.fields().iter().enumerate() {
            let rw_is_pk = self.pk_indices.contains(&idx);
            match sql_server_table_metadata.get(&col.name) {
                None => {
                    return Err(SinkError::SqlServer(anyhow!(format!(
                        "column {} not found in the downstream SQL Server table {}",
                        col.name,
                        self.config.full_object_path()
                    ))));
                }
                Some(sql_server_is_pk) => {
                    if self.is_append_only {
                        continue;
                    }
                    if rw_is_pk && !*sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is listed in `primary_key` but is not a primary-key column of the downstream SQL Server table {}",
                            col.name,
                            self.config.full_object_path(),
                        ))));
                    }
                    if !rw_is_pk && *sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is missing from `primary_key` but is a primary-key column of the downstream SQL Server table {}",
                            col.name,
                            self.config.full_object_path(),
                        ))));
                    }
                }
            }
        }

        if !self.is_append_only {
            let sql_server_pk_count = sql_server_table_metadata
                .values()
                .filter(|is_pk| **is_pk)
                .count();
            if sql_server_pk_count != self.pk_indices.len() {
                return Err(SinkError::SqlServer(anyhow!(format!(
                    "primary key column count does not match between the RisingWave sink ({}) and the SQL Server table {} ({})",
                    self.pk_indices.len(),
                    self.config.full_object_path(),
                    sql_server_pk_count,
                ))));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(SqlServerSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
    }
}

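/// One buffered row-level operation: `Insert` becomes an `INSERT` statement,
/// `Merge` a T-SQL `MERGE` upsert keyed on the primary key, and `Delete` a
/// `DELETE` keyed on the primary key. See `SqlServerSinkWriter::flush`.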
enum SqlOp {
    Insert(OwnedRow),
    Merge(OwnedRow),
    Delete(OwnedRow),
}

pub struct SqlServerSinkWriter {
    config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    sql_client: SqlServerClient,
    ops: Vec<SqlOp>,
}

impl SqlServerSinkWriter {
    async fn new(
        config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let sql_client = SqlServerClient::new(&config).await?;
        let writer = Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            sql_client,
            ops: vec![],
        };
        Ok(writer)
    }

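    // The three helpers below buffer one op each; when the buffer would reach
    // `max_batch_rows` they flush first. Otherwise, `flush` is driven by
    // checkpoint barriers in `SinkWriter::barrier`.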
    async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Delete(row.into_owned_row()));
        Ok(())
    }

    async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Merge(row.into_owned_row()));
        Ok(())
    }

    async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Insert(row.into_owned_row()));
        Ok(())
    }

    async fn flush(&mut self) -> Result<()> {
        use std::fmt::Write;
        if self.ops.is_empty() {
            return Ok(());
        }
        let mut query_str = String::new();
        let col_num = self.schema.fields.len();
        let mut next_param_id = 1;
        let non_pk_col_indices = (0..col_num)
            .filter(|idx| !self.pk_indices.contains(idx))
            .collect::<Vec<usize>>();
        let all_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let all_source_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[SOURCE].[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let pk_match = self
            .pk_indices
            .iter()
            .map(|idx| {
                format!(
                    "[SOURCE].[{}]=[TARGET].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(" AND ");
        let param_placeholders = |param_id: &mut usize| {
            let params = (*param_id..(*param_id + col_num))
                .map(|i| format!("@P{}", i))
                .collect::<Vec<_>>()
                .join(",");
            *param_id += col_num;
            params
        };
        let set_all_source_col = non_pk_col_indices
            .iter()
            .map(|idx| {
                format!(
                    "[{}]=[SOURCE].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(",");
        for op in &self.ops {
            match op {
                SqlOp::Insert(_) => {
                    write!(
                        &mut query_str,
                        "INSERT INTO {} ({}) VALUES ({});",
                        self.config.full_object_path(),
                        all_col_names,
                        param_placeholders(&mut next_param_id),
                    )
                    .unwrap();
                }
                SqlOp::Merge(_) => {
                    write!(
                        &mut query_str,
                        r#"MERGE {} AS [TARGET]
                        USING (VALUES ({})) AS [SOURCE] ({})
                        ON {}
                        WHEN MATCHED THEN UPDATE SET {}
                        WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
                        self.config.full_object_path(),
                        param_placeholders(&mut next_param_id),
                        all_col_names,
                        pk_match,
                        set_all_source_col,
                        all_col_names,
                        all_source_col_names,
                    )
                    .unwrap();
                }
                SqlOp::Delete(_) => {
                    write!(
                        &mut query_str,
                        r#"DELETE FROM {} WHERE {};"#,
                        self.config.full_object_path(),
                        self.pk_indices
                            .iter()
                            .map(|idx| {
                                let condition =
                                    format!("[{}]=@P{}", self.schema[*idx].name, next_param_id);
                                next_param_id += 1;
                                condition
                            })
                            .collect::<Vec<_>>()
                            .join(" AND "),
                    )
                    .unwrap();
                }
            }
        }
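        // Illustrative `query_str` for a single `Merge` op against a
        // hypothetical table `t(a INT PRIMARY KEY, b INT)`, with parameter
        // numbering starting at @P1:
        //
        //   MERGE [mydb].[dbo].[t] AS [TARGET]
        //   USING (VALUES (@P1,@P2)) AS [SOURCE] ([a],[b])
        //   ON [SOURCE].[a]=[TARGET].[a]
        //   WHEN MATCHED THEN UPDATE SET [b]=[SOURCE].[b]
        //   WHEN NOT MATCHED THEN INSERT ([a],[b]) VALUES ([SOURCE].[a],[SOURCE].[b]);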

        let mut query = Query::new(query_str);
        for op in self.ops.drain(..) {
            match op {
                SqlOp::Insert(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Merge(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Delete(row) => {
                    bind_params(
                        &mut query,
                        row,
                        &self.schema,
                        self.pk_indices.iter().copied(),
                    )?;
                }
            }
        }
        query.execute(&mut self.sql_client.inner_client).await?;
        Ok(())
    }
}

#[async_trait]
impl SinkWriter for SqlServerSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    if self.is_append_only {
                        self.insert_one(row).await?;
                    } else {
                        self.upsert_one(row).await?;
                    }
                }
                Op::UpdateInsert => {
                    debug_assert!(!self.is_append_only);
                    self.upsert_one(row).await?;
                }
                Op::Delete => {
                    debug_assert!(!self.is_append_only);
                    self.delete_one(row).await?;
                }
                // An `UpdateDelete` is always paired with a following
                // `UpdateInsert`, which the MERGE upsert covers on its own.
                Op::UpdateDelete => {}
            }
        }
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
        if is_checkpoint {
            self.flush().await?;
        }
        Ok(())
    }
}

#[derive(Debug)]
pub struct SqlServerClient {
    pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
}

impl SqlServerClient {
    async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
        let mut config = Config::new();
        config.host(&msconfig.host);
        config.port(msconfig.port);
        config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
        config.database(&msconfig.database);
        config.trust_cert();
        Self::new_with_config(config).await
    }

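    /// Connects with `TCP_NODELAY` set and follows at most one server-issued
    /// redirect (`tiberius::error::Error::Routing`); gateways such as Azure
    /// SQL may answer the initial handshake with such a redirect.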
    pub async fn new_with_config(mut config: Config) -> Result<Self> {
        let tcp = TcpStream::connect(config.get_addr())
            .await
            .context("failed to connect to sql server")
            .map_err(SinkError::SqlServer)?;
        tcp.set_nodelay(true)
            .context("failed to set nodelay when connecting to sql server")
            .map_err(SinkError::SqlServer)?;

        let client = match Client::connect(config.clone(), tcp.compat_write()).await {
            Ok(client) => client,
            // The server asked us to reconnect to a different address.
            Err(tiberius::error::Error::Routing { host, port }) => {
                config.host(&host);
                config.port(port);
                let tcp = TcpStream::connect(config.get_addr())
                    .await
                    .context("failed to connect to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                tcp.set_nodelay(true)
                    .context("failed to set nodelay when connecting to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                Client::connect(config, tcp.compat_write()).await?
            }
            Err(e) => return Err(e.into()),
        };

        Ok(Self {
            inner_client: client,
        })
    }
}

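/// Binds one parameter per selected column, in iteration order, which must
/// line up with the order the `@P` placeholders were generated in `flush`.
/// Note that `Timestamptz` is bound as an `i64` of epoch microseconds, so the
/// downstream column is presumably a BIGINT rather than a datetimeoffset.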
fn bind_params(
    query: &mut Query<'_>,
    row: impl Row,
    schema: &Schema,
    col_indices: impl Iterator<Item = usize>,
) -> Result<()> {
    use risingwave_common::types::ScalarRefImpl;
    for col_idx in col_indices {
        match row.datum_at(col_idx) {
            Some(data_ref) => match data_ref {
                ScalarRefImpl::Int16(v) => query.bind(v),
                ScalarRefImpl::Int32(v) => query.bind(v),
                ScalarRefImpl::Int64(v) => query.bind(v),
                ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
                ScalarRefImpl::Bool(v) => query.bind(v),
                ScalarRefImpl::Decimal(v) => match v {
                    Decimal::Normalized(d) => {
                        query.bind(decimal_to_sql(&d));
                    }
                    Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
                        tracing::warn!(
                            "NaN, Inf, and -Inf decimals in RisingWave are bound as SQL Server NULL!"
                        );
                        query.bind(None as Option<Numeric>);
                    }
                },
                ScalarRefImpl::Date(v) => query.bind(v.0),
                ScalarRefImpl::Timestamp(v) => query.bind(v.0),
                ScalarRefImpl::Timestamptz(v) => query.bind(v.timestamp_micros()),
                ScalarRefImpl::Time(v) => query.bind(v.0),
                ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
                ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
                ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
                ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
                ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
                ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
                ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
                ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
            },
            None => match schema[col_idx].data_type {
                DataType::Boolean => {
                    query.bind(None as Option<bool>);
                }
                DataType::Int16 => {
                    query.bind(None as Option<i16>);
                }
                DataType::Int32 => {
                    query.bind(None as Option<i32>);
                }
                DataType::Int64 => {
                    query.bind(None as Option<i64>);
                }
                DataType::Float32 => {
                    query.bind(None as Option<f32>);
                }
                DataType::Float64 => {
                    query.bind(None as Option<f64>);
                }
                DataType::Decimal => {
                    query.bind(None as Option<Numeric>);
                }
                DataType::Date => {
                    query.bind(None as Option<chrono::NaiveDate>);
                }
                DataType::Time => {
                    query.bind(None as Option<chrono::NaiveTime>);
                }
                DataType::Timestamp => {
                    query.bind(None as Option<chrono::NaiveDateTime>);
                }
                DataType::Timestamptz => {
                    query.bind(None as Option<i64>);
                }
                DataType::Varchar => {
                    query.bind(None as Option<String>);
                }
                DataType::Bytea => {
                    query.bind(None as Option<Vec<u8>>);
                }
                DataType::Interval => return Err(data_type_not_supported("Interval")),
                DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
                DataType::List(_) => return Err(data_type_not_supported("List")),
                DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
                DataType::Serial => return Err(data_type_not_supported("Serial")),
                DataType::Int256 => return Err(data_type_not_supported("Int256")),
                DataType::Map(_) => return Err(data_type_not_supported("Map")),
            },
        };
    }
    Ok(())
}

fn data_type_not_supported(data_type_name: &str) -> SinkError {
    SinkError::SqlServer(anyhow!(format!(
        "{data_type_name} is not supported in SQL Server"
    )))
}

fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
    match data_type {
        DataType::Boolean
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal
        | DataType::Date
        | DataType::Varchar
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Bytea => Ok(()),
        DataType::Interval => Err(data_type_not_supported("Interval")),
        DataType::Struct(_) => Err(data_type_not_supported("Struct")),
        DataType::List(_) => Err(data_type_not_supported("List")),
        DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
        DataType::Serial => Err(data_type_not_supported("Serial")),
        DataType::Int256 => Err(data_type_not_supported("Int256")),
        DataType::Map(_) => Err(data_type_not_supported("Map")),
    }
}

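/// Repacks a `rust_decimal::Decimal` into a tiberius `Numeric` by
/// reassembling the 96-bit mantissa from its `lo`/`mid`/`hi` words and
/// reapplying the sign. A worked example: `-12.345` unpacks to mantissa
/// 12345 with scale 3, so this returns `Numeric::new_with_scale(-12345, 3)`.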
fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
    let unpacked = decimal.unpack();

    let mut value = (((unpacked.hi as u128) << 64)
        + ((unpacked.mid as u128) << 32)
        + unpacked.lo as u128) as i128;

    if decimal.is_sign_negative() {
        value = -value;
    }

    Numeric::new_with_scale(value, decimal.scale() as u8)
}