use std::collections::{BTreeMap, HashMap};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Decimal};
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use tiberius::numeric::Numeric;
use tiberius::{AuthMethod, Client, ColumnData, Config, Query};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use with_options::WithOptions;

use super::{
    SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkWriterMetrics,
};
use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkParam, SinkWriterParam};

pub const SQLSERVER_SINK: &str = "sqlserver";

fn default_max_batch_rows() -> usize {
    1024
}

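// A sketch of how these options might be supplied in a `CREATE SINK` statement
// (illustrative names and values, not taken from this file):
//
//   CREATE SINK target_sink FROM mv WITH (
//       connector = 'sqlserver',
//       type = 'upsert',
//       primary_key = 'id',
//       sqlserver.host = 'localhost',
//       sqlserver.port = '1433',
//       sqlserver.user = 'sa',
//       sqlserver.password = '...',
//       sqlserver.database = 'mydb',
//       sqlserver.table = 'my_table'
//   );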
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct SqlServerConfig {
    #[serde(rename = "sqlserver.host")]
    pub host: String,
    #[serde(rename = "sqlserver.port")]
    #[serde_as(as = "DisplayFromStr")]
    pub port: u16,
    #[serde(rename = "sqlserver.user")]
    pub user: String,
    #[serde(rename = "sqlserver.password")]
    pub password: String,
    #[serde(rename = "sqlserver.database")]
    pub database: String,
    #[serde(rename = "sqlserver.table")]
    pub table: String,
    #[serde(
        rename = "sqlserver.max_batch_rows",
        default = "default_max_batch_rows"
    )]
    #[serde_as(as = "DisplayFromStr")]
    pub max_batch_rows: usize,
    pub r#type: String,
}

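// `WITH` options arrive as a string-to-string map, hence the `DisplayFromStr`
// fields above; `from_btreemap` round-trips the map through `serde_json::Value`
// to deserialize it into the typed config.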
impl SqlServerConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<SqlServerConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct SqlServerSink {
    pub config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl SqlServerSink {
    pub fn new(
        mut config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        // SQL Server allows at most 2100 parameters per request; keep a safety margin.
        const TIBERIUS_PARAM_MAX: usize = 2000;
        let params_per_op = schema.fields().len();
        let tiberius_max_batch_rows = if params_per_op == 0 {
            config.max_batch_rows
        } else {
            ((TIBERIUS_PARAM_MAX as f64 / params_per_op as f64).floor()) as usize
        };
        if tiberius_max_batch_rows == 0 {
            return Err(SinkError::SqlServer(anyhow!(format!(
                "too many columns: {}",
                params_per_op
            ))));
        }
        config.max_batch_rows = std::cmp::min(config.max_batch_rows, tiberius_max_batch_rows);
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for SqlServerSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let config = SqlServerConfig::from_btreemap(param.properties)?;
        SqlServerSink::new(
            config,
            schema,
            param.downstream_pk,
            param.sink_type.is_append_only(),
        )
    }
}

impl Sink for SqlServerSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = LogSinkerOf<SqlServerSinkWriter>;

    const SINK_NAME: &'static str = SQLSERVER_SINK;

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::SqlServerSink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;

        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert SQL Server sink (please define in `primary_key` field)"
            )));
        }

        for f in self.schema.fields() {
            check_data_type_compatibility(&f.data_type)?;
        }

        let mut sql_server_table_metadata = HashMap::new();
        let mut sql_client = SqlServerClient::new(&self.config).await?;
        let query_table_metadata_error = || {
            SinkError::SqlServer(anyhow!(format!(
                "SQL Server table {} metadata error",
                self.config.table
            )))
        };
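        // The probe below returns one row per column of the target table: the
        // column name plus the PK index id, which is non-NULL exactly when the
        // column participates in the primary-key index.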
        static QUERY_TABLE_METADATA: &str = r#"
SELECT
    col.name AS ColumnName,
    pk.index_id AS PkIndex
FROM
    sys.columns col
LEFT JOIN
    sys.index_columns ic ON ic.object_id = col.object_id AND ic.column_id = col.column_id
LEFT JOIN
    sys.indexes pk ON pk.object_id = col.object_id AND pk.index_id = ic.index_id AND pk.is_primary_key = 1
WHERE
    col.object_id = OBJECT_ID(@P1)
ORDER BY
    col.column_id;"#;
        let rows = sql_client
            .inner_client
            .query(QUERY_TABLE_METADATA, &[&self.config.table])
            .await?
            .into_results()
            .await?;
        for row in rows.into_iter().flatten() {
            let mut iter = row.into_iter();
            let ColumnData::String(Some(col_name)) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            let ColumnData::I32(col_pk_index) =
                iter.next().ok_or_else(query_table_metadata_error)?
            else {
                return Err(query_table_metadata_error());
            };
            sql_server_table_metadata.insert(col_name.into_owned(), col_pk_index.is_some());
        }

        for (idx, col) in self.schema.fields().iter().enumerate() {
            let rw_is_pk = self.pk_indices.contains(&idx);
            match sql_server_table_metadata.get(&col.name) {
                None => {
                    return Err(SinkError::SqlServer(anyhow!(format!(
                        "column {} not found in the downstream SQL Server table {}",
                        col.name, self.config.table
                    ))));
                }
                Some(sql_server_is_pk) => {
                    if self.is_append_only {
                        continue;
                    }
                    if rw_is_pk && !*sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is listed in primary_key but is not a primary-key column of the downstream SQL Server table {}",
                            col.name, self.config.table,
                        ))));
                    }
                    if !rw_is_pk && *sql_server_is_pk {
                        return Err(SinkError::SqlServer(anyhow!(format!(
                            "column {} is a primary-key column of the downstream SQL Server table {} but is missing from primary_key",
                            col.name, self.config.table,
                        ))));
                    }
                }
            }
        }

        if !self.is_append_only {
            let sql_server_pk_count = sql_server_table_metadata
                .values()
                .filter(|is_pk| **is_pk)
                .count();
            if sql_server_pk_count != self.pk_indices.len() {
                return Err(SinkError::SqlServer(anyhow!(format!(
                    "primary-key column count mismatch between the RisingWave sink ({}) and the SQL Server table {} ({})",
                    self.pk_indices.len(),
                    self.config.table,
                    sql_server_pk_count,
                ))));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(SqlServerSinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(SinkWriterMetrics::new(&writer_param)))
    }
}

enum SqlOp {
    Insert(OwnedRow),
    Merge(OwnedRow),
    Delete(OwnedRow),
}

pub struct SqlServerSinkWriter {
    config: SqlServerConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
    sql_client: SqlServerClient,
    ops: Vec<SqlOp>,
}

impl SqlServerSinkWriter {
    async fn new(
        config: SqlServerConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let sql_client = SqlServerClient::new(&config).await?;
        let writer = Self {
            config,
            schema,
            pk_indices,
            is_append_only,
            sql_client,
            ops: vec![],
        };
        Ok(writer)
    }

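    // Each of the three single-row writers below flushes first whenever the
    // buffer is about to reach `max_batch_rows`, so one network round-trip
    // carries a whole batch of statements.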
    async fn delete_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Delete(row.into_owned_row()));
        Ok(())
    }

    async fn upsert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Merge(row.into_owned_row()));
        Ok(())
    }

    async fn insert_one(&mut self, row: RowRef<'_>) -> Result<()> {
        if self.ops.len() + 1 >= self.config.max_batch_rows {
            self.flush().await?;
        }
        self.ops.push(SqlOp::Insert(row.into_owned_row()));
        Ok(())
    }

    async fn flush(&mut self) -> Result<()> {
        use std::fmt::Write;
        if self.ops.is_empty() {
            return Ok(());
        }
        let mut query_str = String::new();
        let col_num = self.schema.fields.len();
        let mut next_param_id = 1;
        let non_pk_col_indices = (0..col_num)
            .filter(|idx| !self.pk_indices.contains(idx))
            .collect::<Vec<usize>>();
        let all_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let all_source_col_names = self
            .schema
            .fields
            .iter()
            .map(|f| format!("[SOURCE].[{}]", f.name))
            .collect::<Vec<_>>()
            .join(",");
        let pk_match = self
            .pk_indices
            .iter()
            .map(|idx| {
                format!(
                    "[SOURCE].[{}]=[TARGET].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(" AND ");
        let param_placeholders = |param_id: &mut usize| {
            let params = (*param_id..(*param_id + col_num))
                .map(|i| format!("@P{}", i))
                .collect::<Vec<_>>()
                .join(",");
            *param_id += col_num;
            params
        };
        let set_all_source_col = non_pk_col_indices
            .iter()
            .map(|idx| {
                format!(
                    "[{}]=[SOURCE].[{}]",
                    &self.schema[*idx].name, &self.schema[*idx].name
                )
            })
            .collect::<Vec<_>>()
            .join(",");
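        // For a hypothetical table t(id PK, v), a single Merge op below renders
        // roughly as:
        //   MERGE [t] AS [TARGET]
        //   USING (VALUES (@P1,@P2)) AS [SOURCE] ([id],[v])
        //   ON [SOURCE].[id]=[TARGET].[id]
        //   WHEN MATCHED THEN UPDATE SET [v]=[SOURCE].[v]
        //   WHEN NOT MATCHED THEN INSERT ([id],[v]) VALUES ([SOURCE].[id],[SOURCE].[v]);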
        for op in &self.ops {
            match op {
                SqlOp::Insert(_) => {
                    write!(
                        &mut query_str,
                        "INSERT INTO [{}] ({}) VALUES ({});",
                        self.config.table,
                        all_col_names,
                        param_placeholders(&mut next_param_id),
                    )
                    .unwrap();
                }
                SqlOp::Merge(_) => {
                    write!(
                        &mut query_str,
                        r#"MERGE [{}] AS [TARGET]
                        USING (VALUES ({})) AS [SOURCE] ({})
                        ON {}
                        WHEN MATCHED THEN UPDATE SET {}
                        WHEN NOT MATCHED THEN INSERT ({}) VALUES ({});"#,
                        self.config.table,
                        param_placeholders(&mut next_param_id),
                        all_col_names,
                        pk_match,
                        set_all_source_col,
                        all_col_names,
                        all_source_col_names,
                    )
                    .unwrap();
                }
                SqlOp::Delete(_) => {
                    write!(
                        &mut query_str,
                        r#"DELETE FROM [{}] WHERE {};"#,
                        self.config.table,
                        self.pk_indices
                            .iter()
                            .map(|idx| {
                                let condition =
                                    format!("[{}]=@P{}", self.schema[*idx].name, next_param_id);
                                next_param_id += 1;
                                condition
                            })
                            .collect::<Vec<_>>()
                            .join(" AND "),
                    )
                    .unwrap();
                }
            }
        }

        let mut query = Query::new(query_str);
        for op in self.ops.drain(..) {
            match op {
                SqlOp::Insert(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Merge(row) => {
                    bind_params(&mut query, row, &self.schema, 0..col_num)?;
                }
                SqlOp::Delete(row) => {
                    bind_params(
                        &mut query,
                        row,
                        &self.schema,
                        self.pk_indices.iter().copied(),
                    )?;
                }
            }
        }
        query.execute(&mut self.sql_client.inner_client).await?;
        Ok(())
    }
}

#[async_trait]
impl SinkWriter for SqlServerSinkWriter {
    async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> {
        Ok(())
    }

    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
        for (op, row) in chunk.rows() {
            match op {
                Op::Insert => {
                    if self.is_append_only {
                        self.insert_one(row).await?;
                    } else {
                        self.upsert_one(row).await?;
                    }
                }
                Op::UpdateInsert => {
                    debug_assert!(!self.is_append_only);
                    self.upsert_one(row).await?;
                }
                Op::Delete => {
                    debug_assert!(!self.is_append_only);
                    self.delete_one(row).await?;
                }
                // An UpdateDelete is paired with the UpdateInsert that follows
                // it; the MERGE emitted there overwrites the row in place, so
                // nothing needs to be done here.
                Op::UpdateDelete => {}
            }
        }
        Ok(())
    }

    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
        // Flush buffered rows on checkpoint barriers so they are persisted
        // before the epoch is committed.
        if is_checkpoint {
            self.flush().await?;
        }
        Ok(())
    }
}

#[derive(Debug)]
pub struct SqlServerClient {
    pub inner_client: Client<tokio_util::compat::Compat<TcpStream>>,
}

impl SqlServerClient {
    async fn new(msconfig: &SqlServerConfig) -> Result<Self> {
        let mut config = Config::new();
        config.host(&msconfig.host);
        config.port(msconfig.port);
        config.authentication(AuthMethod::sql_server(&msconfig.user, &msconfig.password));
        config.database(&msconfig.database);
        config.trust_cert();
        Self::new_with_config(config).await
    }

    pub async fn new_with_config(mut config: Config) -> Result<Self> {
        let tcp = TcpStream::connect(config.get_addr())
            .await
            .context("failed to connect to sql server")
            .map_err(SinkError::SqlServer)?;
        tcp.set_nodelay(true)
            .context("failed to set nodelay when connecting to sql server")
            .map_err(SinkError::SqlServer)?;

        let client = match Client::connect(config.clone(), tcp.compat_write()).await {
            Ok(client) => client,
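            // `Error::Routing` means the server redirected us to another node
            // (seen with Azure SQL, for example); reconnect once to the
            // advertised address.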
            Err(tiberius::error::Error::Routing { host, port }) => {
                config.host(&host);
                config.port(port);
                let tcp = TcpStream::connect(config.get_addr())
                    .await
                    .context("failed to connect to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                tcp.set_nodelay(true)
                    .context("failed to set nodelay when connecting to sql server after routing")
                    .map_err(SinkError::SqlServer)?;
                Client::connect(config, tcp.compat_write()).await?
            }
            Err(e) => return Err(e.into()),
        };

        Ok(Self {
            inner_client: client,
        })
    }
}

fn bind_params(
    query: &mut Query<'_>,
    row: impl Row,
    schema: &Schema,
    col_indices: impl Iterator<Item = usize>,
) -> Result<()> {
    use risingwave_common::types::ScalarRefImpl;
    for col_idx in col_indices {
        match row.datum_at(col_idx) {
            Some(data_ref) => match data_ref {
                ScalarRefImpl::Int16(v) => query.bind(v),
                ScalarRefImpl::Int32(v) => query.bind(v),
                ScalarRefImpl::Int64(v) => query.bind(v),
                ScalarRefImpl::Float32(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Float64(v) => query.bind(v.into_inner()),
                ScalarRefImpl::Utf8(v) => query.bind(v.to_owned()),
                ScalarRefImpl::Bool(v) => query.bind(v),
                ScalarRefImpl::Decimal(v) => match v {
                    Decimal::Normalized(d) => {
                        query.bind(decimal_to_sql(&d));
                    }
                    Decimal::NaN | Decimal::PositiveInf | Decimal::NegativeInf => {
                        tracing::warn!(
                            "Inf, -Inf and NaN in a RisingWave decimal are converted to SQL Server NULL!"
                        );
                        query.bind(None as Option<Numeric>);
                    }
                },
                ScalarRefImpl::Date(v) => query.bind(v.0),
                ScalarRefImpl::Timestamp(v) => query.bind(v.0),
                ScalarRefImpl::Timestamptz(v) => query.bind(v.timestamp_micros()),
                ScalarRefImpl::Time(v) => query.bind(v.0),
                ScalarRefImpl::Bytea(v) => query.bind(v.to_vec()),
                ScalarRefImpl::Interval(_) => return Err(data_type_not_supported("Interval")),
                ScalarRefImpl::Jsonb(_) => return Err(data_type_not_supported("Jsonb")),
                ScalarRefImpl::Struct(_) => return Err(data_type_not_supported("Struct")),
                ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")),
                ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")),
                ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")),
                ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")),
            },
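            // A NULL datum is still bound through a concrete `Option<T>` so
            // tiberius knows which SQL type to declare for the parameter.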
            None => match schema[col_idx].data_type {
                DataType::Boolean => {
                    query.bind(None as Option<bool>);
                }
                DataType::Int16 => {
                    query.bind(None as Option<i16>);
                }
                DataType::Int32 => {
                    query.bind(None as Option<i32>);
                }
                DataType::Int64 => {
                    query.bind(None as Option<i64>);
                }
                DataType::Float32 => {
                    query.bind(None as Option<f32>);
                }
                DataType::Float64 => {
                    query.bind(None as Option<f64>);
                }
                DataType::Decimal => {
                    query.bind(None as Option<Numeric>);
                }
                DataType::Date => {
                    query.bind(None as Option<chrono::NaiveDate>);
                }
                DataType::Time => {
                    query.bind(None as Option<chrono::NaiveTime>);
                }
                DataType::Timestamp => {
                    query.bind(None as Option<chrono::NaiveDateTime>);
                }
                DataType::Timestamptz => {
                    query.bind(None as Option<i64>);
                }
                DataType::Varchar => {
                    query.bind(None as Option<String>);
                }
                DataType::Bytea => {
                    query.bind(None as Option<Vec<u8>>);
                }
                DataType::Interval => return Err(data_type_not_supported("Interval")),
                DataType::Struct(_) => return Err(data_type_not_supported("Struct")),
                DataType::List(_) => return Err(data_type_not_supported("List")),
                DataType::Jsonb => return Err(data_type_not_supported("Jsonb")),
                DataType::Serial => return Err(data_type_not_supported("Serial")),
                DataType::Int256 => return Err(data_type_not_supported("Int256")),
                DataType::Map(_) => return Err(data_type_not_supported("Map")),
            },
        };
    }
    Ok(())
}

fn data_type_not_supported(data_type_name: &str) -> SinkError {
    SinkError::SqlServer(anyhow!(format!(
        "{data_type_name} is not supported in SQL Server"
    )))
}

fn check_data_type_compatibility(data_type: &DataType) -> Result<()> {
    match data_type {
        DataType::Boolean
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal
        | DataType::Date
        | DataType::Varchar
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Bytea => Ok(()),
        DataType::Interval => Err(data_type_not_supported("Interval")),
        DataType::Struct(_) => Err(data_type_not_supported("Struct")),
        DataType::List(_) => Err(data_type_not_supported("List")),
        DataType::Jsonb => Err(data_type_not_supported("Jsonb")),
        DataType::Serial => Err(data_type_not_supported("Serial")),
        DataType::Int256 => Err(data_type_not_supported("Int256")),
        DataType::Map(_) => Err(data_type_not_supported("Map")),
    }
}

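// `rust_decimal` keeps a 96-bit unsigned mantissa in the `lo`/`mid`/`hi` words
// plus a sign flag and a scale. For example, -12.345 unpacks to lo=12345,
// mid=0, hi=0 with scale 3, which the function below reassembles into
// Numeric::new_with_scale(-12345, 3).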
fn decimal_to_sql(decimal: &rust_decimal::Decimal) -> Numeric {
    let unpacked = decimal.unpack();

    let mut value = (((unpacked.hi as u128) << 64)
        + ((unpacked.mid as u128) << 32)
        + unpacked.lo as u128) as i128;

    if decimal.is_sign_negative() {
        value = -value;
    }

    Numeric::new_with_scale(value, decimal.scale() as u8)
}