risingwave_connector/source/cdc/external/sql_server.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16
17use anyhow::{Context, anyhow};
18use futures::stream::BoxStream;
19use futures::{StreamExt, TryStreamExt, pin_mut, stream};
20use futures_async_stream::try_stream;
21use itertools::Itertools;
22use risingwave_common::bail;
23use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
24use risingwave_common::row::OwnedRow;
25use risingwave_common::types::{DataType, ScalarImpl};
26use serde_derive::{Deserialize, Serialize};
27use tiberius::{Config, Query, QueryItem};
28
29use crate::error::{ConnectorError, ConnectorResult};
30use crate::parser::{ScalarImplTiberiusWrapper, sql_server_row_to_owned_row};
31use crate::sink::sqlserver::SqlServerClient;
32use crate::source::CdcTableSnapshotSplit;
33use crate::source::cdc::external::{
34    CdcOffset, CdcOffsetParseFunc, CdcTableSnapshotSplitOption, DebeziumOffset,
35    ExternalTableConfig, ExternalTableReader, SchemaTableName,
36};
37
38// The maximum commit_lsn value in Sql Server
39const MAX_COMMIT_LSN: &str = "ffffffff:ffffffff:ffff";
40
41#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct SqlServerOffset {
43    // https://learn.microsoft.com/en-us/answers/questions/1328359/how-to-accurately-sequence-change-data-capture-dat
44    pub change_lsn: String,
45    pub commit_lsn: String,
46}
47
48// only compare the lsn field
49impl PartialOrd for SqlServerOffset {
50    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
51        match self.change_lsn.partial_cmp(&other.change_lsn) {
52            Some(Ordering::Equal) => self.commit_lsn.partial_cmp(&other.commit_lsn),
53            other => other,
54        }
55    }
56}
57
58impl SqlServerOffset {
59    pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
60        let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
61            .with_context(|| format!("invalid upstream offset: {}", offset))?;
62
63        Ok(Self {
64            change_lsn: dbz_offset
65                .source_offset
66                .change_lsn
67                .context("invalid sql server change_lsn")?,
68            commit_lsn: dbz_offset
69                .source_offset
70                .commit_lsn
71                .context("invalid sql server commit_lsn")?,
72        })
73    }
74}
75
76pub struct SqlServerExternalTable {
77    column_descs: Vec<ColumnDesc>,
78    pk_names: Vec<String>,
79}
80
81impl SqlServerExternalTable {
82    pub async fn connect(config: ExternalTableConfig) -> ConnectorResult<Self> {
83        tracing::debug!("connect to sql server");
84
85        let mut client_config = Config::new();
86
87        client_config.host(&config.host);
88        client_config.database(&config.database);
89        client_config.port(config.port.parse::<u16>().unwrap());
90        client_config.authentication(tiberius::AuthMethod::sql_server(
91            &config.username,
92            &config.password,
93        ));
94        // TODO(kexiang): use trust_cert_ca, trust_cert is not secure
95        if config.encrypt == "true" {
96            client_config.encryption(tiberius::EncryptionLevel::Required);
97        }
98        client_config.trust_cert();
99
100        let mut client = SqlServerClient::new_with_config(client_config).await?;
101
102        let mut column_descs = vec![];
103        let mut pk_names = vec![];
104        {
105            let sql = Query::new(format!(
106                "SELECT
107                    COLUMN_NAME,
108                    DATA_TYPE
109                FROM
110                    INFORMATION_SCHEMA.COLUMNS
111                WHERE
112                    TABLE_SCHEMA = '{}'
113                    AND TABLE_NAME = '{}'",
114                config.schema.clone(),
115                config.table.clone(),
116            ));
117
118            let mut stream = sql.query(&mut client.inner_client).await?;
119            while let Some(item) = stream.try_next().await? {
120                match item {
121                    QueryItem::Metadata(_) => {}
122                    QueryItem::Row(row) => {
123                        let col_name: &str = row.try_get(0)?.unwrap();
124                        let col_type: &str = row.try_get(1)?.unwrap();
125                        column_descs.push(ColumnDesc::named(
126                            col_name,
127                            ColumnId::placeholder(),
128                            mssql_type_to_rw_type(col_type, col_name)?,
129                        ));
130                    }
131                }
132            }
133        }
134        {
135            let sql = Query::new(format!(
136                "SELECT kcu.COLUMN_NAME
137                FROM
138                    INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
139                JOIN
140                    INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
141                    ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME AND
142                    tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA AND
143                    tc.TABLE_NAME = kcu.TABLE_NAME
144                WHERE
145                    tc.CONSTRAINT_TYPE = 'PRIMARY KEY' AND
146                    tc.TABLE_SCHEMA = '{}' AND tc.TABLE_NAME = '{}'",
147                config.schema, config.table,
148            ));
149
150            let mut stream = sql.query(&mut client.inner_client).await?;
151            while let Some(item) = stream.try_next().await? {
152                match item {
153                    QueryItem::Metadata(_) => {}
154                    QueryItem::Row(row) => {
155                        let pk_name: &str = row.try_get(0)?.unwrap();
156                        pk_names.push(pk_name.to_owned());
157                    }
158                }
159            }
160        }
161
162        // The table does not exist
163        if column_descs.is_empty() {
164            bail!(
165                "Sql Server table '{}'.'{}' not found in '{}'",
166                config.schema,
167                config.table,
168                config.database
169            );
170        }
171
172        Ok(Self {
173            column_descs,
174            pk_names,
175        })
176    }
177
178    pub fn column_descs(&self) -> &Vec<ColumnDesc> {
179        &self.column_descs
180    }
181
182    pub fn pk_names(&self) -> &Vec<String> {
183        &self.pk_names
184    }
185}
186
187fn mssql_type_to_rw_type(col_type: &str, col_name: &str) -> ConnectorResult<DataType> {
188    let dtype = match col_type.to_lowercase().as_str() {
189        "bit" => DataType::Boolean,
190        "binary" | "varbinary" => DataType::Bytea,
191        "tinyint" | "smallint" => DataType::Int16,
192        "integer" | "int" => DataType::Int32,
193        "bigint" => DataType::Int64,
194        "real" => DataType::Float32,
195        "float" => DataType::Float64,
196        "decimal" | "numeric" => DataType::Decimal,
197        "date" => DataType::Date,
198        "time" => DataType::Time,
199        "datetime" | "datetime2" | "smalldatetime" => DataType::Timestamp,
200        "datetimeoffset" => DataType::Timestamptz,
201        "char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext" | "xml"
202        | "uniqueidentifier" => DataType::Varchar,
203        "money" => DataType::Decimal,
204        mssql_type => {
205            return Err(anyhow!(
206                "Unsupported Sql Server data type: {:?}, column name: {}",
207                mssql_type,
208                col_name
209            )
210            .into());
211        }
212    };
213    Ok(dtype)
214}
215
216#[derive(Debug)]
217pub struct SqlServerExternalTableReader {
218    rw_schema: Schema,
219    field_names: String,
220    client: tokio::sync::Mutex<SqlServerClient>,
221}
222
223impl ExternalTableReader for SqlServerExternalTableReader {
224    async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
225        let mut client = self.client.lock().await;
226        // start a transaction to read max start_lsn.
227        let row = client
228            .inner_client
229            .simple_query(String::from("SELECT sys.fn_cdc_get_max_lsn()"))
230            .await?
231            .into_row()
232            .await?
233            .expect("No result returned by `SELECT sys.fn_cdc_get_max_lsn()`");
234        // An example of change_lsn or commit_lsn: "00000027:00000ac0:0002" from debezium
235        // sys.fn_cdc_get_max_lsn() returns a 10 bytes array, we convert it to a hex string here.
236        let max_lsn = match row.try_get::<&[u8], usize>(0)? {
237            Some(bytes) => {
238                let mut hex_string = String::with_capacity(bytes.len() * 2 + 2);
239                assert_eq!(
240                    bytes.len(),
241                    10,
242                    "sys.fn_cdc_get_max_lsn() should return a 10 bytes array."
243                );
244                for byte in &bytes[0..4] {
245                    hex_string.push_str(&format!("{:02x}", byte));
246                }
247                hex_string.push(':');
248                for byte in &bytes[4..8] {
249                    hex_string.push_str(&format!("{:02x}", byte));
250                }
251                hex_string.push(':');
252                for byte in &bytes[8..10] {
253                    hex_string.push_str(&format!("{:02x}", byte));
254                }
255                hex_string
256            }
257            None => bail!(
258                "None is returned by `SELECT sys.fn_cdc_get_max_lsn()`, please ensure Sql Server Agent is running."
259            ),
260        };
261
262        tracing::debug!("current max_lsn: {}", max_lsn);
263
264        Ok(CdcOffset::SqlServer(SqlServerOffset {
265            change_lsn: max_lsn,
266            commit_lsn: MAX_COMMIT_LSN.into(),
267        }))
268    }
269
270    fn snapshot_read(
271        &self,
272        table_name: SchemaTableName,
273        start_pk: Option<OwnedRow>,
274        primary_keys: Vec<String>,
275        limit: u32,
276    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
277        self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
278    }
279
280    fn get_parallel_cdc_splits(
281        &self,
282        _options: CdcTableSnapshotSplitOption,
283    ) -> BoxStream<'_, ConnectorResult<CdcTableSnapshotSplit>> {
284        // TODO(zw): feat: impl
285        stream::empty::<ConnectorResult<CdcTableSnapshotSplit>>().boxed()
286    }
287
288    fn split_snapshot_read(
289        &self,
290        _table_name: SchemaTableName,
291        _left: OwnedRow,
292        _right: OwnedRow,
293        _split_columns: Vec<Field>,
294    ) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
295        todo!("implement SqlServer CDC parallelized backfill")
296    }
297}
298
299impl SqlServerExternalTableReader {
300    pub async fn new(
301        config: ExternalTableConfig,
302        rw_schema: Schema,
303        pk_indices: Vec<usize>,
304    ) -> ConnectorResult<Self> {
305        tracing::info!(
306            ?rw_schema,
307            ?pk_indices,
308            "create sql server external table reader"
309        );
310        let mut client_config = Config::new();
311
312        client_config.host(&config.host);
313        client_config.database(&config.database);
314        client_config.port(config.port.parse::<u16>().unwrap());
315        client_config.authentication(tiberius::AuthMethod::sql_server(
316            &config.username,
317            &config.password,
318        ));
319        // TODO(kexiang): use trust_cert_ca, trust_cert is not secure
320        if config.encrypt == "true" {
321            client_config.encryption(tiberius::EncryptionLevel::Required);
322        }
323        client_config.trust_cert();
324
325        let client = SqlServerClient::new_with_config(client_config).await?;
326
327        let field_names = rw_schema
328            .fields
329            .iter()
330            .map(|f| Self::quote_column(&f.name))
331            .join(",");
332
333        Ok(Self {
334            rw_schema,
335            field_names,
336            client: tokio::sync::Mutex::new(client),
337        })
338    }
339
340    pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
341        Box::new(move |offset| {
342            Ok(CdcOffset::SqlServer(
343                SqlServerOffset::parse_debezium_offset(offset)?,
344            ))
345        })
346    }
347
348    #[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
349    async fn snapshot_read_inner(
350        &self,
351        table_name: SchemaTableName,
352        start_pk_row: Option<OwnedRow>,
353        primary_keys: Vec<String>,
354        limit: u32,
355    ) {
356        let order_key = primary_keys
357            .iter()
358            .map(|col| Self::quote_column(col))
359            .join(",");
360        let mut sql = Query::new(if start_pk_row.is_none() {
361            format!(
362                "SELECT {} FROM {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
363                self.field_names,
364                Self::get_normalized_table_name(&table_name),
365                order_key,
366            )
367        } else {
368            let filter_expr = Self::filter_expression(&primary_keys);
369            format!(
370                "SELECT {} FROM {} WHERE {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
371                self.field_names,
372                Self::get_normalized_table_name(&table_name),
373                filter_expr,
374                order_key,
375            )
376        });
377
378        let mut client = self.client.lock().await;
379
380        // FIXME(kexiang): Set session timezone to UTC
381        if let Some(pk_row) = start_pk_row {
382            let params: Vec<Option<ScalarImpl>> = pk_row.into_iter().collect();
383            for param in params {
384                // primary key should not be null, so it's safe to unwrap
385                sql.bind(ScalarImplTiberiusWrapper::from(param.unwrap()));
386            }
387        }
388
389        let stream = sql.query(&mut client.inner_client).await?.into_row_stream();
390
391        let row_stream = stream.map(|res| {
392            // convert sql server row into OwnedRow
393            let mut row = res?;
394            Ok::<_, ConnectorError>(sql_server_row_to_owned_row(&mut row, &self.rw_schema))
395        });
396
397        pin_mut!(row_stream);
398
399        #[for_await]
400        for row in row_stream {
401            let row = row?;
402            yield row;
403        }
404    }
405
406    pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
407        format!(
408            "\"{}\".\"{}\"",
409            table_name.schema_name, table_name.table_name
410        )
411    }
412
413    // sql server cannot leverage the given key to narrow down the range of scan,
414    // we need to rewrite the comparison conditions by our own.
415    // (a, b) > (x, y) => ("a" > @P1) OR (("a" = @P1) AND ("b" > @P2))
416    fn filter_expression(columns: &[String]) -> String {
417        let mut conditions = vec![];
418        // push the first condition
419        conditions.push(format!("({} > @P{})", Self::quote_column(&columns[0]), 1));
420        for i in 2..=columns.len() {
421            // '=' condition
422            let mut condition = String::new();
423            for (j, col) in columns.iter().enumerate().take(i - 1) {
424                if j == 0 {
425                    condition.push_str(&format!("({} = @P{})", Self::quote_column(col), j + 1));
426                } else {
427                    condition.push_str(&format!(
428                        " AND ({} = @P{})",
429                        Self::quote_column(col),
430                        j + 1
431                    ));
432                }
433            }
434            // '>' condition
435            condition.push_str(&format!(
436                " AND ({} > @P{})",
437                Self::quote_column(&columns[i - 1]),
438                i
439            ));
440            conditions.push(format!("({})", condition));
441        }
442        if columns.len() > 1 {
443            conditions.join(" OR ")
444        } else {
445            conditions.join("")
446        }
447    }
448
449    fn quote_column(column: &str) -> String {
450        format!("\"{}\"", column)
451    }
452}
453
#[cfg(test)]
mod tests {
    use crate::source::cdc::external::SqlServerExternalTableReader;

    /// Checks the hand-expanded tuple comparison for one, two and three key columns.
    #[test]
    fn test_sql_server_filter_expr() {
        let cols = vec!["id".to_owned()];
        let expr = SqlServerExternalTableReader::filter_expression(&cols);
        assert_eq!(expr, "(\"id\" > @P1)");

        let cols = vec!["aa".to_owned(), "bb".to_owned()];
        let expr = SqlServerExternalTableReader::filter_expression(&cols);
        assert_eq!(
            expr,
            "(\"aa\" > @P1) OR ((\"aa\" = @P1) AND (\"bb\" > @P2))"
        );

        let cols = vec!["aa".to_owned(), "bb".to_owned(), "cc".to_owned()];
        let expr = SqlServerExternalTableReader::filter_expression(&cols);
        assert_eq!(
            expr,
            "(\"aa\" > @P1) OR ((\"aa\" = @P1) AND (\"bb\" > @P2)) OR ((\"aa\" = @P1) AND (\"bb\" = @P2) AND (\"cc\" > @P3))"
        );
    }
}