risingwave_connector/source/cdc/external/
sql_server.rs1use std::cmp::Ordering;
16
17use anyhow::{Context, anyhow};
18use futures::stream::BoxStream;
19use futures::{StreamExt, TryStreamExt, pin_mut, stream};
20use futures_async_stream::try_stream;
21use itertools::Itertools;
22use risingwave_common::bail;
23use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
24use risingwave_common::row::OwnedRow;
25use risingwave_common::types::{DataType, ScalarImpl};
26use serde_derive::{Deserialize, Serialize};
27use tiberius::{Config, Query, QueryItem};
28
29use crate::error::{ConnectorError, ConnectorResult};
30use crate::parser::{ScalarImplTiberiusWrapper, sql_server_row_to_owned_row};
31use crate::sink::sqlserver::SqlServerClient;
32use crate::source::CdcTableSnapshotSplit;
33use crate::source::cdc::external::{
34 CdcOffset, CdcOffsetParseFunc, CdcTableSnapshotSplitOption, DebeziumOffset,
35 ExternalTableConfig, ExternalTableReader, SchemaTableName,
36};
37
/// Commit LSN placeholder paired with the max LSN reported by `current_cdc_offset`.
///
/// Every hex digit is `f`, so under plain string comparison it sorts at or after any
/// other LSN string written in the same lowercase `xxxxxxxx:xxxxxxxx:xxxx` layout.
const MAX_COMMIT_LSN: &str = "ffffffff:ffffffff:ffff";
40
41#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct SqlServerOffset {
43 pub change_lsn: String,
45 pub commit_lsn: String,
46}
47
48impl PartialOrd for SqlServerOffset {
50 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
51 match self.change_lsn.partial_cmp(&other.change_lsn) {
52 Some(Ordering::Equal) => self.commit_lsn.partial_cmp(&other.commit_lsn),
53 other => other,
54 }
55 }
56}
57
58impl SqlServerOffset {
59 pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
60 let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
61 .with_context(|| format!("invalid upstream offset: {}", offset))?;
62
63 Ok(Self {
64 change_lsn: dbz_offset
65 .source_offset
66 .change_lsn
67 .context("invalid sql server change_lsn")?,
68 commit_lsn: dbz_offset
69 .source_offset
70 .commit_lsn
71 .context("invalid sql server commit_lsn")?,
72 })
73 }
74}
75
76pub struct SqlServerExternalTable {
77 column_descs: Vec<ColumnDesc>,
78 pk_names: Vec<String>,
79}
80
81impl SqlServerExternalTable {
82 pub async fn connect(config: ExternalTableConfig) -> ConnectorResult<Self> {
83 tracing::debug!("connect to sql server");
84
85 let mut client_config = Config::new();
86
87 client_config.host(&config.host);
88 client_config.database(&config.database);
89 client_config.port(config.port.parse::<u16>().unwrap());
90 client_config.authentication(tiberius::AuthMethod::sql_server(
91 &config.username,
92 &config.password,
93 ));
94 if config.encrypt == "true" {
96 client_config.encryption(tiberius::EncryptionLevel::Required);
97 }
98 client_config.trust_cert();
99
100 let mut client = SqlServerClient::new_with_config(client_config).await?;
101
102 let mut column_descs = vec![];
103 let mut pk_names = vec![];
104 {
105 let sql = Query::new(format!(
106 "SELECT
107 COLUMN_NAME,
108 DATA_TYPE
109 FROM
110 INFORMATION_SCHEMA.COLUMNS
111 WHERE
112 TABLE_SCHEMA = '{}'
113 AND TABLE_NAME = '{}'",
114 config.schema.clone(),
115 config.table.clone(),
116 ));
117
118 let mut stream = sql.query(&mut client.inner_client).await?;
119 while let Some(item) = stream.try_next().await? {
120 match item {
121 QueryItem::Metadata(_) => {}
122 QueryItem::Row(row) => {
123 let col_name: &str = row.try_get(0)?.unwrap();
124 let col_type: &str = row.try_get(1)?.unwrap();
125 column_descs.push(ColumnDesc::named(
126 col_name,
127 ColumnId::placeholder(),
128 mssql_type_to_rw_type(col_type, col_name)?,
129 ));
130 }
131 }
132 }
133 }
134 {
135 let sql = Query::new(format!(
136 "SELECT kcu.COLUMN_NAME
137 FROM
138 INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
139 JOIN
140 INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
141 ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME AND
142 tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA AND
143 tc.TABLE_NAME = kcu.TABLE_NAME
144 WHERE
145 tc.CONSTRAINT_TYPE = 'PRIMARY KEY' AND
146 tc.TABLE_SCHEMA = '{}' AND tc.TABLE_NAME = '{}'",
147 config.schema, config.table,
148 ));
149
150 let mut stream = sql.query(&mut client.inner_client).await?;
151 while let Some(item) = stream.try_next().await? {
152 match item {
153 QueryItem::Metadata(_) => {}
154 QueryItem::Row(row) => {
155 let pk_name: &str = row.try_get(0)?.unwrap();
156 pk_names.push(pk_name.to_owned());
157 }
158 }
159 }
160 }
161
162 if column_descs.is_empty() {
164 bail!(
165 "Sql Server table '{}'.'{}' not found in '{}'",
166 config.schema,
167 config.table,
168 config.database
169 );
170 }
171
172 Ok(Self {
173 column_descs,
174 pk_names,
175 })
176 }
177
178 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
179 &self.column_descs
180 }
181
182 pub fn pk_names(&self) -> &Vec<String> {
183 &self.pk_names
184 }
185}
186
187fn mssql_type_to_rw_type(col_type: &str, col_name: &str) -> ConnectorResult<DataType> {
188 let dtype = match col_type.to_lowercase().as_str() {
189 "bit" => DataType::Boolean,
190 "binary" | "varbinary" => DataType::Bytea,
191 "tinyint" | "smallint" => DataType::Int16,
192 "integer" | "int" => DataType::Int32,
193 "bigint" => DataType::Int64,
194 "real" => DataType::Float32,
195 "float" => DataType::Float64,
196 "decimal" | "numeric" => DataType::Decimal,
197 "date" => DataType::Date,
198 "time" => DataType::Time,
199 "datetime" | "datetime2" | "smalldatetime" => DataType::Timestamp,
200 "datetimeoffset" => DataType::Timestamptz,
201 "char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext" | "xml"
202 | "uniqueidentifier" => DataType::Varchar,
203 "money" => DataType::Decimal,
204 mssql_type => {
205 return Err(anyhow!(
206 "Unsupported Sql Server data type: {:?}, column name: {}",
207 mssql_type,
208 col_name
209 )
210 .into());
211 }
212 };
213 Ok(dtype)
214}
215
216#[derive(Debug)]
217pub struct SqlServerExternalTableReader {
218 rw_schema: Schema,
219 field_names: String,
220 client: tokio::sync::Mutex<SqlServerClient>,
221}
222
223impl ExternalTableReader for SqlServerExternalTableReader {
224 async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
225 let mut client = self.client.lock().await;
226 let row = client
228 .inner_client
229 .simple_query(String::from("SELECT sys.fn_cdc_get_max_lsn()"))
230 .await?
231 .into_row()
232 .await?
233 .expect("No result returned by `SELECT sys.fn_cdc_get_max_lsn()`");
234 let max_lsn = match row.try_get::<&[u8], usize>(0)? {
237 Some(bytes) => {
238 let mut hex_string = String::with_capacity(bytes.len() * 2 + 2);
239 assert_eq!(
240 bytes.len(),
241 10,
242 "sys.fn_cdc_get_max_lsn() should return a 10 bytes array."
243 );
244 for byte in &bytes[0..4] {
245 hex_string.push_str(&format!("{:02x}", byte));
246 }
247 hex_string.push(':');
248 for byte in &bytes[4..8] {
249 hex_string.push_str(&format!("{:02x}", byte));
250 }
251 hex_string.push(':');
252 for byte in &bytes[8..10] {
253 hex_string.push_str(&format!("{:02x}", byte));
254 }
255 hex_string
256 }
257 None => bail!(
258 "None is returned by `SELECT sys.fn_cdc_get_max_lsn()`, please ensure Sql Server Agent is running."
259 ),
260 };
261
262 tracing::debug!("current max_lsn: {}", max_lsn);
263
264 Ok(CdcOffset::SqlServer(SqlServerOffset {
265 change_lsn: max_lsn,
266 commit_lsn: MAX_COMMIT_LSN.into(),
267 }))
268 }
269
270 fn snapshot_read(
271 &self,
272 table_name: SchemaTableName,
273 start_pk: Option<OwnedRow>,
274 primary_keys: Vec<String>,
275 limit: u32,
276 ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
277 self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
278 }
279
280 fn get_parallel_cdc_splits(
281 &self,
282 _options: CdcTableSnapshotSplitOption,
283 ) -> BoxStream<'_, ConnectorResult<CdcTableSnapshotSplit>> {
284 stream::empty::<ConnectorResult<CdcTableSnapshotSplit>>().boxed()
286 }
287
288 fn split_snapshot_read(
289 &self,
290 _table_name: SchemaTableName,
291 _left: OwnedRow,
292 _right: OwnedRow,
293 _split_columns: Vec<Field>,
294 ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
295 todo!("implement SqlServer CDC parallelized backfill")
296 }
297}
298
299impl SqlServerExternalTableReader {
300 pub async fn new(
301 config: ExternalTableConfig,
302 rw_schema: Schema,
303 pk_indices: Vec<usize>,
304 ) -> ConnectorResult<Self> {
305 tracing::info!(
306 ?rw_schema,
307 ?pk_indices,
308 "create sql server external table reader"
309 );
310 let mut client_config = Config::new();
311
312 client_config.host(&config.host);
313 client_config.database(&config.database);
314 client_config.port(config.port.parse::<u16>().unwrap());
315 client_config.authentication(tiberius::AuthMethod::sql_server(
316 &config.username,
317 &config.password,
318 ));
319 if config.encrypt == "true" {
321 client_config.encryption(tiberius::EncryptionLevel::Required);
322 }
323 client_config.trust_cert();
324
325 let client = SqlServerClient::new_with_config(client_config).await?;
326
327 let field_names = rw_schema
328 .fields
329 .iter()
330 .map(|f| Self::quote_column(&f.name))
331 .join(",");
332
333 Ok(Self {
334 rw_schema,
335 field_names,
336 client: tokio::sync::Mutex::new(client),
337 })
338 }
339
340 pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
341 Box::new(move |offset| {
342 Ok(CdcOffset::SqlServer(
343 SqlServerOffset::parse_debezium_offset(offset)?,
344 ))
345 })
346 }
347
348 #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
349 async fn snapshot_read_inner(
350 &self,
351 table_name: SchemaTableName,
352 start_pk_row: Option<OwnedRow>,
353 primary_keys: Vec<String>,
354 limit: u32,
355 ) {
356 let order_key = primary_keys
357 .iter()
358 .map(|col| Self::quote_column(col))
359 .join(",");
360 let mut sql = Query::new(if start_pk_row.is_none() {
361 format!(
362 "SELECT {} FROM {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
363 self.field_names,
364 Self::get_normalized_table_name(&table_name),
365 order_key,
366 )
367 } else {
368 let filter_expr = Self::filter_expression(&primary_keys);
369 format!(
370 "SELECT {} FROM {} WHERE {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
371 self.field_names,
372 Self::get_normalized_table_name(&table_name),
373 filter_expr,
374 order_key,
375 )
376 });
377
378 let mut client = self.client.lock().await;
379
380 if let Some(pk_row) = start_pk_row {
382 let params: Vec<Option<ScalarImpl>> = pk_row.into_iter().collect();
383 for param in params {
384 sql.bind(ScalarImplTiberiusWrapper::from(param.unwrap()));
386 }
387 }
388
389 let stream = sql.query(&mut client.inner_client).await?.into_row_stream();
390
391 let row_stream = stream.map(|res| {
392 let mut row = res?;
394 Ok::<_, ConnectorError>(sql_server_row_to_owned_row(&mut row, &self.rw_schema))
395 });
396
397 pin_mut!(row_stream);
398
399 #[for_await]
400 for row in row_stream {
401 let row = row?;
402 yield row;
403 }
404 }
405
406 pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
407 format!(
408 "\"{}\".\"{}\"",
409 table_name.schema_name, table_name.table_name
410 )
411 }
412
413 fn filter_expression(columns: &[String]) -> String {
417 let mut conditions = vec![];
418 conditions.push(format!("({} > @P{})", Self::quote_column(&columns[0]), 1));
420 for i in 2..=columns.len() {
421 let mut condition = String::new();
423 for (j, col) in columns.iter().enumerate().take(i - 1) {
424 if j == 0 {
425 condition.push_str(&format!("({} = @P{})", Self::quote_column(col), j + 1));
426 } else {
427 condition.push_str(&format!(
428 " AND ({} = @P{})",
429 Self::quote_column(col),
430 j + 1
431 ));
432 }
433 }
434 condition.push_str(&format!(
436 " AND ({} > @P{})",
437 Self::quote_column(&columns[i - 1]),
438 i
439 ));
440 conditions.push(format!("({})", condition));
441 }
442 if columns.len() > 1 {
443 conditions.join(" OR ")
444 } else {
445 conditions.join("")
446 }
447 }
448
449 fn quote_column(column: &str) -> String {
450 format!("\"{}\"", column)
451 }
452}
453
#[cfg(test)]
mod tests {
    use crate::source::cdc::external::SqlServerExternalTableReader;

    /// Checks the generated key-range predicate for a single-column key and for a
    /// three-column composite key.
    #[test]
    fn test_sql_server_filter_expr() {
        let single_key = ["id".to_owned()];
        assert_eq!(
            SqlServerExternalTableReader::filter_expression(&single_key),
            "(\"id\" > @P1)"
        );

        let composite_key = ["aa", "bb", "cc"].map(str::to_owned);
        assert_eq!(
            SqlServerExternalTableReader::filter_expression(&composite_key),
            "(\"aa\" > @P1) OR ((\"aa\" = @P1) AND (\"bb\" > @P2)) OR ((\"aa\" = @P1) AND (\"bb\" = @P2) AND (\"cc\" > @P3))"
        );
    }
}