risingwave_connector/source/cdc/external/
sql_server.rs1use std::cmp::Ordering;
16
17use anyhow::{Context, anyhow};
18use futures::stream::BoxStream;
19use futures::{StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use itertools::Itertools;
22use risingwave_common::bail;
23use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema};
24use risingwave_common::row::OwnedRow;
25use risingwave_common::types::{DataType, ScalarImpl};
26use serde_derive::{Deserialize, Serialize};
27use tiberius::{Config, Query, QueryItem};
28
29use crate::error::{ConnectorError, ConnectorResult};
30use crate::parser::{ScalarImplTiberiusWrapper, sql_server_row_to_owned_row};
31use crate::sink::sqlserver::SqlServerClient;
32use crate::source::cdc::external::{
33 CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
34 SchemaTableName,
35};
36
/// `commit_lsn` attached to the offset produced by `current_cdc_offset`.
///
/// It is the largest value expressible in the lowercase `xxxxxxxx:xxxxxxxx:xxxx` hex layout.
/// `SqlServerOffset`'s ordering compares `commit_lsn` only when `change_lsn` ties, so this value
/// sorts the "current" offset after any offset with the same `change_lsn` in that layout.
const MAX_COMMIT_LSN: &str = concat!("ffffffff", ":", "ffffffff", ":", "ffff");
39
40#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
41pub struct SqlServerOffset {
42 pub change_lsn: String,
44 pub commit_lsn: String,
45}
46
47impl PartialOrd for SqlServerOffset {
49 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50 match self.change_lsn.partial_cmp(&other.change_lsn) {
51 Some(Ordering::Equal) => self.commit_lsn.partial_cmp(&other.commit_lsn),
52 other => other,
53 }
54 }
55}
56
57impl SqlServerOffset {
58 pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
59 let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
60 .with_context(|| format!("invalid upstream offset: {}", offset))?;
61
62 Ok(Self {
63 change_lsn: dbz_offset
64 .source_offset
65 .change_lsn
66 .context("invalid sql server change_lsn")?,
67 commit_lsn: dbz_offset
68 .source_offset
69 .commit_lsn
70 .context("invalid sql server commit_lsn")?,
71 })
72 }
73}
74
75pub struct SqlServerExternalTable {
76 column_descs: Vec<ColumnDesc>,
77 pk_names: Vec<String>,
78}
79
80impl SqlServerExternalTable {
81 pub async fn connect(config: ExternalTableConfig) -> ConnectorResult<Self> {
82 tracing::debug!("connect to sql server");
83
84 let mut client_config = Config::new();
85
86 client_config.host(&config.host);
87 client_config.database(&config.database);
88 client_config.port(config.port.parse::<u16>().unwrap());
89 client_config.authentication(tiberius::AuthMethod::sql_server(
90 &config.username,
91 &config.password,
92 ));
93 if config.encrypt == "true" {
95 client_config.encryption(tiberius::EncryptionLevel::Required);
96 }
97 client_config.trust_cert();
98
99 let mut client = SqlServerClient::new_with_config(client_config).await?;
100
101 let mut column_descs = vec![];
102 let mut pk_names = vec![];
103 {
104 let sql = Query::new(format!(
105 "SELECT
106 COLUMN_NAME,
107 DATA_TYPE
108 FROM
109 INFORMATION_SCHEMA.COLUMNS
110 WHERE
111 TABLE_SCHEMA = '{}'
112 AND TABLE_NAME = '{}'",
113 config.schema.clone(),
114 config.table.clone(),
115 ));
116
117 let mut stream = sql.query(&mut client.inner_client).await?;
118 while let Some(item) = stream.try_next().await? {
119 match item {
120 QueryItem::Metadata(_) => {}
121 QueryItem::Row(row) => {
122 let col_name: &str = row.try_get(0)?.unwrap();
123 let col_type: &str = row.try_get(1)?.unwrap();
124 column_descs.push(ColumnDesc::named(
125 col_name,
126 ColumnId::placeholder(),
127 mssql_type_to_rw_type(col_type, col_name)?,
128 ));
129 }
130 }
131 }
132 }
133 {
134 let sql = Query::new(format!(
135 "SELECT kcu.COLUMN_NAME
136 FROM
137 INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
138 JOIN
139 INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
140 ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME AND
141 tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA AND
142 tc.TABLE_NAME = kcu.TABLE_NAME
143 WHERE
144 tc.CONSTRAINT_TYPE = 'PRIMARY KEY' AND
145 tc.TABLE_SCHEMA = '{}' AND tc.TABLE_NAME = '{}'",
146 config.schema, config.table,
147 ));
148
149 let mut stream = sql.query(&mut client.inner_client).await?;
150 while let Some(item) = stream.try_next().await? {
151 match item {
152 QueryItem::Metadata(_) => {}
153 QueryItem::Row(row) => {
154 let pk_name: &str = row.try_get(0)?.unwrap();
155 pk_names.push(pk_name.to_owned());
156 }
157 }
158 }
159 }
160
161 if column_descs.is_empty() {
163 bail!(
164 "Sql Server table '{}'.'{}' not found in '{}'",
165 config.schema,
166 config.table,
167 config.database
168 );
169 }
170
171 Ok(Self {
172 column_descs,
173 pk_names,
174 })
175 }
176
177 pub fn column_descs(&self) -> &Vec<ColumnDesc> {
178 &self.column_descs
179 }
180
181 pub fn pk_names(&self) -> &Vec<String> {
182 &self.pk_names
183 }
184}
185
186fn mssql_type_to_rw_type(col_type: &str, col_name: &str) -> ConnectorResult<DataType> {
187 let dtype = match col_type.to_lowercase().as_str() {
188 "bit" => DataType::Boolean,
189 "binary" | "varbinary" => DataType::Bytea,
190 "tinyint" | "smallint" => DataType::Int16,
191 "integer" | "int" => DataType::Int32,
192 "bigint" => DataType::Int64,
193 "real" => DataType::Float32,
194 "float" => DataType::Float64,
195 "decimal" | "numeric" => DataType::Decimal,
196 "date" => DataType::Date,
197 "time" => DataType::Time,
198 "datetime" | "datetime2" | "smalldatetime" => DataType::Timestamp,
199 "datetimeoffset" => DataType::Timestamptz,
200 "char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext" | "xml"
201 | "uniqueidentifier" => DataType::Varchar,
202 "money" => DataType::Decimal,
203 mssql_type => {
204 return Err(anyhow!(
205 "Unsupported Sql Server data type: {:?}, column name: {}",
206 mssql_type,
207 col_name
208 )
209 .into());
210 }
211 };
212 Ok(dtype)
213}
214
215#[derive(Debug)]
216pub struct SqlServerExternalTableReader {
217 rw_schema: Schema,
218 field_names: String,
219 client: tokio::sync::Mutex<SqlServerClient>,
220}
221
222impl ExternalTableReader for SqlServerExternalTableReader {
223 async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
224 let mut client = self.client.lock().await;
225 let row = client
227 .inner_client
228 .simple_query(String::from("SELECT sys.fn_cdc_get_max_lsn()"))
229 .await?
230 .into_row()
231 .await?
232 .expect("No result returned by `SELECT sys.fn_cdc_get_max_lsn()`");
233 let max_lsn = match row.try_get::<&[u8], usize>(0)? {
236 Some(bytes) => {
237 let mut hex_string = String::with_capacity(bytes.len() * 2 + 2);
238 assert_eq!(
239 bytes.len(),
240 10,
241 "sys.fn_cdc_get_max_lsn() should return a 10 bytes array."
242 );
243 for byte in &bytes[0..4] {
244 hex_string.push_str(&format!("{:02x}", byte));
245 }
246 hex_string.push(':');
247 for byte in &bytes[4..8] {
248 hex_string.push_str(&format!("{:02x}", byte));
249 }
250 hex_string.push(':');
251 for byte in &bytes[8..10] {
252 hex_string.push_str(&format!("{:02x}", byte));
253 }
254 hex_string
255 }
256 None => bail!(
257 "None is returned by `SELECT sys.fn_cdc_get_max_lsn()`, please ensure Sql Server Agent is running."
258 ),
259 };
260
261 tracing::debug!("current max_lsn: {}", max_lsn);
262
263 Ok(CdcOffset::SqlServer(SqlServerOffset {
264 change_lsn: max_lsn,
265 commit_lsn: MAX_COMMIT_LSN.into(),
266 }))
267 }
268
269 fn snapshot_read(
270 &self,
271 table_name: SchemaTableName,
272 start_pk: Option<OwnedRow>,
273 primary_keys: Vec<String>,
274 limit: u32,
275 ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
276 self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
277 }
278}
279
280impl SqlServerExternalTableReader {
281 pub async fn new(
282 config: ExternalTableConfig,
283 rw_schema: Schema,
284 pk_indices: Vec<usize>,
285 ) -> ConnectorResult<Self> {
286 tracing::info!(
287 ?rw_schema,
288 ?pk_indices,
289 "create sql server external table reader"
290 );
291 let mut client_config = Config::new();
292
293 client_config.host(&config.host);
294 client_config.database(&config.database);
295 client_config.port(config.port.parse::<u16>().unwrap());
296 client_config.authentication(tiberius::AuthMethod::sql_server(
297 &config.username,
298 &config.password,
299 ));
300 if config.encrypt == "true" {
302 client_config.encryption(tiberius::EncryptionLevel::Required);
303 }
304 client_config.trust_cert();
305
306 let client = SqlServerClient::new_with_config(client_config).await?;
307
308 let field_names = rw_schema
309 .fields
310 .iter()
311 .map(|f| Self::quote_column(&f.name))
312 .join(",");
313
314 Ok(Self {
315 rw_schema,
316 field_names,
317 client: tokio::sync::Mutex::new(client),
318 })
319 }
320
321 pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
322 Box::new(move |offset| {
323 Ok(CdcOffset::SqlServer(
324 SqlServerOffset::parse_debezium_offset(offset)?,
325 ))
326 })
327 }
328
329 #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
330 async fn snapshot_read_inner(
331 &self,
332 table_name: SchemaTableName,
333 start_pk_row: Option<OwnedRow>,
334 primary_keys: Vec<String>,
335 limit: u32,
336 ) {
337 let order_key = primary_keys
338 .iter()
339 .map(|col| Self::quote_column(col))
340 .join(",");
341 let mut sql = Query::new(if start_pk_row.is_none() {
342 format!(
343 "SELECT {} FROM {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
344 self.field_names,
345 Self::get_normalized_table_name(&table_name),
346 order_key,
347 )
348 } else {
349 let filter_expr = Self::filter_expression(&primary_keys);
350 format!(
351 "SELECT {} FROM {} WHERE {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
352 self.field_names,
353 Self::get_normalized_table_name(&table_name),
354 filter_expr,
355 order_key,
356 )
357 });
358
359 let mut client = self.client.lock().await;
360
361 if let Some(pk_row) = start_pk_row {
363 let params: Vec<Option<ScalarImpl>> = pk_row.into_iter().collect();
364 for param in params {
365 sql.bind(ScalarImplTiberiusWrapper::from(param.unwrap()));
367 }
368 }
369
370 let stream = sql.query(&mut client.inner_client).await?.into_row_stream();
371
372 let row_stream = stream.map(|res| {
373 let mut row = res?;
375 Ok::<_, ConnectorError>(sql_server_row_to_owned_row(&mut row, &self.rw_schema))
376 });
377
378 pin_mut!(row_stream);
379
380 #[for_await]
381 for row in row_stream {
382 let row = row?;
383 yield row;
384 }
385 }
386
387 pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
388 format!(
389 "\"{}\".\"{}\"",
390 table_name.schema_name, table_name.table_name
391 )
392 }
393
394 fn filter_expression(columns: &[String]) -> String {
398 let mut conditions = vec![];
399 conditions.push(format!("({} > @P{})", Self::quote_column(&columns[0]), 1));
401 for i in 2..=columns.len() {
402 let mut condition = String::new();
404 for (j, col) in columns.iter().enumerate().take(i - 1) {
405 if j == 0 {
406 condition.push_str(&format!("({} = @P{})", Self::quote_column(col), j + 1));
407 } else {
408 condition.push_str(&format!(
409 " AND ({} = @P{})",
410 Self::quote_column(col),
411 j + 1
412 ));
413 }
414 }
415 condition.push_str(&format!(
417 " AND ({} > @P{})",
418 Self::quote_column(&columns[i - 1]),
419 i
420 ));
421 conditions.push(format!("({})", condition));
422 }
423 if columns.len() > 1 {
424 conditions.join(" OR ")
425 } else {
426 conditions.join("")
427 }
428 }
429
430 fn quote_column(column: &str) -> String {
431 format!("\"{}\"", column)
432 }
433}
434
#[cfg(test)]
mod tests {
    use crate::source::cdc::external::SqlServerExternalTableReader;

    /// Converts `cols` to owned names and builds their filter expression.
    fn filter_for(cols: &[&str]) -> String {
        let owned: Vec<String> = cols.iter().map(|c| String::from(*c)).collect();
        SqlServerExternalTableReader::filter_expression(&owned)
    }

    #[test]
    fn test_sql_server_filter_expr() {
        assert_eq!(filter_for(&["id"]), r#"("id" > @P1)"#);

        assert_eq!(
            filter_for(&["aa", "bb", "cc"]),
            r#"("aa" > @P1) OR (("aa" = @P1) AND ("bb" > @P2)) OR (("aa" = @P1) AND ("bb" = @P2) AND ("cc" > @P3))"#
        );
    }
}