risingwave_connector/source/cdc/external/sql_server.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::cmp::Ordering;
16
17use anyhow::{Context, anyhow};
18use futures::stream::BoxStream;
19use futures::{StreamExt, TryStreamExt, pin_mut};
20use futures_async_stream::try_stream;
21use itertools::Itertools;
22use risingwave_common::bail;
23use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema};
24use risingwave_common::row::OwnedRow;
25use risingwave_common::types::{DataType, ScalarImpl};
26use serde_derive::{Deserialize, Serialize};
27use tiberius::{Config, Query, QueryItem};
28
29use crate::error::{ConnectorError, ConnectorResult};
30use crate::parser::{ScalarImplTiberiusWrapper, sql_server_row_to_owned_row};
31use crate::sink::sqlserver::SqlServerClient;
32use crate::source::cdc::external::{
33    CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
34    SchemaTableName,
35};
36
/// Largest `commit_lsn` value representable in Sql Server, in the
/// `xxxxxxxx:xxxxxxxx:xxxx` hex form used for LSN strings in this module.
const MAX_COMMIT_LSN: &'static str = "ffffffff:ffffffff:ffff";
39
40#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
41pub struct SqlServerOffset {
42    // https://learn.microsoft.com/en-us/answers/questions/1328359/how-to-accurately-sequence-change-data-capture-dat
43    pub change_lsn: String,
44    pub commit_lsn: String,
45}
46
47// only compare the lsn field
48impl PartialOrd for SqlServerOffset {
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        match self.change_lsn.partial_cmp(&other.change_lsn) {
51            Some(Ordering::Equal) => self.commit_lsn.partial_cmp(&other.commit_lsn),
52            other => other,
53        }
54    }
55}
56
57impl SqlServerOffset {
58    pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
59        let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
60            .with_context(|| format!("invalid upstream offset: {}", offset))?;
61
62        Ok(Self {
63            change_lsn: dbz_offset
64                .source_offset
65                .change_lsn
66                .context("invalid sql server change_lsn")?,
67            commit_lsn: dbz_offset
68                .source_offset
69                .commit_lsn
70                .context("invalid sql server commit_lsn")?,
71        })
72    }
73}
74
75pub struct SqlServerExternalTable {
76    column_descs: Vec<ColumnDesc>,
77    pk_names: Vec<String>,
78}
79
80impl SqlServerExternalTable {
81    pub async fn connect(config: ExternalTableConfig) -> ConnectorResult<Self> {
82        tracing::debug!("connect to sql server");
83
84        let mut client_config = Config::new();
85
86        client_config.host(&config.host);
87        client_config.database(&config.database);
88        client_config.port(config.port.parse::<u16>().unwrap());
89        client_config.authentication(tiberius::AuthMethod::sql_server(
90            &config.username,
91            &config.password,
92        ));
93        // TODO(kexiang): use trust_cert_ca, trust_cert is not secure
94        if config.encrypt == "true" {
95            client_config.encryption(tiberius::EncryptionLevel::Required);
96        }
97        client_config.trust_cert();
98
99        let mut client = SqlServerClient::new_with_config(client_config).await?;
100
101        let mut column_descs = vec![];
102        let mut pk_names = vec![];
103        {
104            let sql = Query::new(format!(
105                "SELECT
106                    COLUMN_NAME,
107                    DATA_TYPE
108                FROM
109                    INFORMATION_SCHEMA.COLUMNS
110                WHERE
111                    TABLE_SCHEMA = '{}'
112                    AND TABLE_NAME = '{}'",
113                config.schema.clone(),
114                config.table.clone(),
115            ));
116
117            let mut stream = sql.query(&mut client.inner_client).await?;
118            while let Some(item) = stream.try_next().await? {
119                match item {
120                    QueryItem::Metadata(_) => {}
121                    QueryItem::Row(row) => {
122                        let col_name: &str = row.try_get(0)?.unwrap();
123                        let col_type: &str = row.try_get(1)?.unwrap();
124                        column_descs.push(ColumnDesc::named(
125                            col_name,
126                            ColumnId::placeholder(),
127                            mssql_type_to_rw_type(col_type, col_name)?,
128                        ));
129                    }
130                }
131            }
132        }
133        {
134            let sql = Query::new(format!(
135                "SELECT kcu.COLUMN_NAME
136                FROM
137                    INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
138                JOIN
139                    INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
140                    ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME AND
141                    tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA AND
142                    tc.TABLE_NAME = kcu.TABLE_NAME
143                WHERE
144                    tc.CONSTRAINT_TYPE = 'PRIMARY KEY' AND
145                    tc.TABLE_SCHEMA = '{}' AND tc.TABLE_NAME = '{}'",
146                config.schema, config.table,
147            ));
148
149            let mut stream = sql.query(&mut client.inner_client).await?;
150            while let Some(item) = stream.try_next().await? {
151                match item {
152                    QueryItem::Metadata(_) => {}
153                    QueryItem::Row(row) => {
154                        let pk_name: &str = row.try_get(0)?.unwrap();
155                        pk_names.push(pk_name.to_owned());
156                    }
157                }
158            }
159        }
160
161        // The table does not exist
162        if column_descs.is_empty() {
163            bail!(
164                "Sql Server table '{}'.'{}' not found in '{}'",
165                config.schema,
166                config.table,
167                config.database
168            );
169        }
170
171        Ok(Self {
172            column_descs,
173            pk_names,
174        })
175    }
176
177    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
178        &self.column_descs
179    }
180
181    pub fn pk_names(&self) -> &Vec<String> {
182        &self.pk_names
183    }
184}
185
186fn mssql_type_to_rw_type(col_type: &str, col_name: &str) -> ConnectorResult<DataType> {
187    let dtype = match col_type.to_lowercase().as_str() {
188        "bit" => DataType::Boolean,
189        "binary" | "varbinary" => DataType::Bytea,
190        "tinyint" | "smallint" => DataType::Int16,
191        "integer" | "int" => DataType::Int32,
192        "bigint" => DataType::Int64,
193        "real" => DataType::Float32,
194        "float" => DataType::Float64,
195        "decimal" | "numeric" => DataType::Decimal,
196        "date" => DataType::Date,
197        "time" => DataType::Time,
198        "datetime" | "datetime2" | "smalldatetime" => DataType::Timestamp,
199        "datetimeoffset" => DataType::Timestamptz,
200        "char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext" | "xml"
201        | "uniqueidentifier" => DataType::Varchar,
202        mssql_type => {
203            return Err(anyhow!(
204                "Unsupported Sql Server data type: {:?}, column name: {}",
205                mssql_type,
206                col_name
207            )
208            .into());
209        }
210    };
211    Ok(dtype)
212}
213
214#[derive(Debug)]
215pub struct SqlServerExternalTableReader {
216    rw_schema: Schema,
217    field_names: String,
218    client: tokio::sync::Mutex<SqlServerClient>,
219}
220
221impl ExternalTableReader for SqlServerExternalTableReader {
222    async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
223        let mut client = self.client.lock().await;
224        // start a transaction to read max start_lsn.
225        let row = client
226            .inner_client
227            .simple_query(String::from("SELECT sys.fn_cdc_get_max_lsn()"))
228            .await?
229            .into_row()
230            .await?
231            .expect("No result returned by `SELECT sys.fn_cdc_get_max_lsn()`");
232        // An example of change_lsn or commit_lsn: "00000027:00000ac0:0002" from debezium
233        // sys.fn_cdc_get_max_lsn() returns a 10 bytes array, we convert it to a hex string here.
234        let max_lsn = match row.try_get::<&[u8], usize>(0)? {
235            Some(bytes) => {
236                let mut hex_string = String::with_capacity(bytes.len() * 2 + 2);
237                assert_eq!(
238                    bytes.len(),
239                    10,
240                    "sys.fn_cdc_get_max_lsn() should return a 10 bytes array."
241                );
242                for byte in &bytes[0..4] {
243                    hex_string.push_str(&format!("{:02x}", byte));
244                }
245                hex_string.push(':');
246                for byte in &bytes[4..8] {
247                    hex_string.push_str(&format!("{:02x}", byte));
248                }
249                hex_string.push(':');
250                for byte in &bytes[8..10] {
251                    hex_string.push_str(&format!("{:02x}", byte));
252                }
253                hex_string
254            }
255            None => bail!(
256                "None is returned by `SELECT sys.fn_cdc_get_max_lsn()`, please ensure Sql Server Agent is running."
257            ),
258        };
259
260        tracing::debug!("current max_lsn: {}", max_lsn);
261
262        Ok(CdcOffset::SqlServer(SqlServerOffset {
263            change_lsn: max_lsn,
264            commit_lsn: MAX_COMMIT_LSN.into(),
265        }))
266    }
267
268    fn snapshot_read(
269        &self,
270        table_name: SchemaTableName,
271        start_pk: Option<OwnedRow>,
272        primary_keys: Vec<String>,
273        limit: u32,
274    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
275        self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
276    }
277}
278
279impl SqlServerExternalTableReader {
280    pub async fn new(
281        config: ExternalTableConfig,
282        rw_schema: Schema,
283        pk_indices: Vec<usize>,
284    ) -> ConnectorResult<Self> {
285        tracing::info!(
286            ?rw_schema,
287            ?pk_indices,
288            "create sql server external table reader"
289        );
290        let mut client_config = Config::new();
291
292        client_config.host(&config.host);
293        client_config.database(&config.database);
294        client_config.port(config.port.parse::<u16>().unwrap());
295        client_config.authentication(tiberius::AuthMethod::sql_server(
296            &config.username,
297            &config.password,
298        ));
299        // TODO(kexiang): use trust_cert_ca, trust_cert is not secure
300        if config.encrypt == "true" {
301            client_config.encryption(tiberius::EncryptionLevel::Required);
302        }
303        client_config.trust_cert();
304
305        let client = SqlServerClient::new_with_config(client_config).await?;
306
307        let field_names = rw_schema
308            .fields
309            .iter()
310            .map(|f| Self::quote_column(&f.name))
311            .join(",");
312
313        Ok(Self {
314            rw_schema,
315            field_names,
316            client: tokio::sync::Mutex::new(client),
317        })
318    }
319
320    pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
321        Box::new(move |offset| {
322            Ok(CdcOffset::SqlServer(
323                SqlServerOffset::parse_debezium_offset(offset)?,
324            ))
325        })
326    }
327
328    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
329    async fn snapshot_read_inner(
330        &self,
331        table_name: SchemaTableName,
332        start_pk_row: Option<OwnedRow>,
333        primary_keys: Vec<String>,
334        limit: u32,
335    ) {
336        let order_key = primary_keys
337            .iter()
338            .map(|col| Self::quote_column(col))
339            .join(",");
340        let mut sql = Query::new(if start_pk_row.is_none() {
341            format!(
342                "SELECT {} FROM {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
343                self.field_names,
344                Self::get_normalized_table_name(&table_name),
345                order_key,
346            )
347        } else {
348            let filter_expr = Self::filter_expression(&primary_keys);
349            format!(
350                "SELECT {} FROM {} WHERE {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
351                self.field_names,
352                Self::get_normalized_table_name(&table_name),
353                filter_expr,
354                order_key,
355            )
356        });
357
358        let mut client = self.client.lock().await;
359
360        // FIXME(kexiang): Set session timezone to UTC
361        if let Some(pk_row) = start_pk_row {
362            let params: Vec<Option<ScalarImpl>> = pk_row.into_iter().collect();
363            for param in params {
364                // primary key should not be null, so it's safe to unwrap
365                sql.bind(ScalarImplTiberiusWrapper::from(param.unwrap()));
366            }
367        }
368
369        let stream = sql.query(&mut client.inner_client).await?.into_row_stream();
370
371        let row_stream = stream.map(|res| {
372            // convert sql server row into OwnedRow
373            let mut row = res?;
374            Ok::<_, ConnectorError>(sql_server_row_to_owned_row(&mut row, &self.rw_schema))
375        });
376
377        pin_mut!(row_stream);
378
379        #[for_await]
380        for row in row_stream {
381            let row = row?;
382            yield row;
383        }
384    }
385
386    pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
387        format!(
388            "\"{}\".\"{}\"",
389            table_name.schema_name, table_name.table_name
390        )
391    }
392
393    // sql server cannot leverage the given key to narrow down the range of scan,
394    // we need to rewrite the comparison conditions by our own.
395    // (a, b) > (x, y) => ("a" > @P1) OR (("a" = @P1) AND ("b" > @P2))
396    fn filter_expression(columns: &[String]) -> String {
397        let mut conditions = vec![];
398        // push the first condition
399        conditions.push(format!("({} > @P{})", Self::quote_column(&columns[0]), 1));
400        for i in 2..=columns.len() {
401            // '=' condition
402            let mut condition = String::new();
403            for (j, col) in columns.iter().enumerate().take(i - 1) {
404                if j == 0 {
405                    condition.push_str(&format!("({} = @P{})", Self::quote_column(col), j + 1));
406                } else {
407                    condition.push_str(&format!(
408                        " AND ({} = @P{})",
409                        Self::quote_column(col),
410                        j + 1
411                    ));
412                }
413            }
414            // '>' condition
415            condition.push_str(&format!(
416                " AND ({} > @P{})",
417                Self::quote_column(&columns[i - 1]),
418                i
419            ));
420            conditions.push(format!("({})", condition));
421        }
422        if columns.len() > 1 {
423            conditions.join(" OR ")
424        } else {
425            conditions.join("")
426        }
427    }
428
429    fn quote_column(column: &str) -> String {
430        format!("\"{}\"", column)
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::SqlServerExternalTableReader;

    /// `filter_expression` must expand `(k1, ..., kn) > (@P1, ..., @Pn)` into an
    /// OR-of-ANDs predicate over positional parameters.
    #[test]
    fn test_sql_server_filter_expr() {
        let single_key = ["id".to_owned()];
        assert_eq!(
            SqlServerExternalTableReader::filter_expression(&single_key),
            r#"("id" > @P1)"#
        );

        let composite_key = ["aa", "bb", "cc"].map(String::from);
        assert_eq!(
            SqlServerExternalTableReader::filter_expression(&composite_key),
            r#"("aa" > @P1) OR (("aa" = @P1) AND ("bb" > @P2)) OR (("aa" = @P1) AND ("bb" = @P2) AND ("cc" > @P3))"#
        );
    }
}